From 96ef08c891f102d131584ccd3707c3aa167f52c7 Mon Sep 17 00:00:00 2001 From: Noah Syrkis Date: Fri, 15 Dec 2023 08:19:39 -0300 Subject: [PATCH] elaborate workaround apple not fixing metal gpu --- config.yaml | 9 +- notebook.ipynb | 720 ++++++++++++++++++------------------------------- src/data.py | 2 +- src/fmri.py | 22 +- src/graph.py | 26 ++ 5 files changed, 316 insertions(+), 463 deletions(-) create mode 100644 src/graph.py diff --git a/config.yaml b/config.yaml index 236a7f3..8a674dc 100644 --- a/config.yaml +++ b/config.yaml @@ -2,14 +2,15 @@ neuroscope: source : img # fmri or img image_size : 64 # only applicable to COCO (neuroscope) images - embed_dim : 256 # bottleneck size (latent space) + latent_dim : 128 # bottleneck size (latent space) batch_size : 64 - in_chans : 3 # number of input channels - chan_start : 24 # number of output channels for first conv layer + in_chans : 1 # number of input channels + chan_start : 32 # number of output channels for first conv layer conv_branch : 2 # number of branches for conv layers kernel_size : 3 # 3 or 5 stride : 2 # 1 or 2 or 4 - conv_layers : 4 # could be quite deep i think + gcn_layers : 5 # 2 or 3 + conv_layers : 5 # could be quite deep i think fc_layers : 1 # should probably be 1 lr : 0.001 # seems 0.001 is best even with batch norm epochs : 50 # between 10 and 50 diff --git a/notebook.ipynb b/notebook.ipynb index 90abcdf..bb0f72d 100644 --- a/notebook.ipynb +++ b/notebook.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 126, + "execution_count": 107, "metadata": {}, "outputs": [], "source": [ @@ -19,7 +19,9 @@ "from jax import vmap, jit, lax, random, grad, value_and_grad\n", "import jax.numpy as jnp\n", "import optax\n", + "from jax import config\n", "import jraph\n", + "import numpy as np\n", "\n", "import wandb\n", "import numpy as np\n", @@ -28,22 +30,24 @@ "import time\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", + "import os, pickle\n", "\n", "import syrkis\n", "from src.data import load_subjects, make_kfolds\n", - "from src.fmri import get_bold_with_coords_and_faces" + "from src.fmri import get_bold_with_coords_and_faces as get_mesh" ] }, { "cell_type": "code", - "execution_count": 127, + "execution_count": 108, "metadata": {}, "outputs": [], "source": [ "# GLOBALS\n", "rng = random.PRNGKey(0)\n", "cfg = syrkis.train.load_config()['neuroscope']\n", - "opt = optax.adamw(learning_rate=cfg['lr'])" + "opt = optax.adamw(learning_rate=cfg['lr'])\n", + "config.update(\"jax_platform_name\", \"cpu\")" ] }, { @@ -55,15 +59,11 @@ }, { "cell_type": "code", - "execution_count": 128, + "execution_count": 109, "metadata": {}, "outputs": [], "source": [ - "subjects = load_subjects(['subj05', 'subj07'], cfg['image_size'])\n", - "n_samples = sum([len(s[0]) for s in subjects.values()])\n", - "kfolds = make_kfolds(subjects, cfg)\n", - "train_batches, eval_batches = next(kfolds)\n", - "lh_sample, rh_sample, img_sample, subject_idx_sample = next(train_batches)" + "subjects = load_subjects(['subj05', 'subj07'], cfg['image_size'])" ] }, { @@ -75,37 +75,111 @@ }, { "cell_type": "code", - "execution_count": 129, + "execution_count": 110, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(379470,\n", - " ['umbrella', 'chair', 'chair', 'chair', 'dining table', 'chair'],\n", - " ['A patio table surrounded by chairs on a patio.',\n", - " 'a wooden table with 4 chairs and an umbrella\\n',\n", - " 'wooden patio set with umbrella in a backyard',\n", - " 'A table with umbrella and chairs sits on a patio.',\n", - " 'Patio table with umbrella surrounded by chairs in the back yeard'])" - ] - }, - "execution_count": 129, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "lh = subjects['subj05'][0][0]\n", - "coords, features, faces = get_bold_with_coords_and_faces(lh, 'subj05', 'lh')\n", - "senders = faces[:, 0] + faces[:, 1] + faces[:, 2]\n", - "receivers = faces[:, 1] + faces[:, 2] + faces[:, 0]\n", - "n_node, n_edge = jnp.array([features.shape[0]]), jnp.array([senders.shape[0]])\n", - "graph = jraph.GraphsTuple(\n", - " n_node=n_node, n_edge=n_edge, edges=None, globals=None,\n", - " nodes=features, senders=senders, receivers=receivers)\n", - "meta = subjects['subj05'][-1]\n", - "meta[4]\n" + "def make_samples(sample, subject, n_subjects):\n", + " lhs, rhs, imgs = sample\n", + " fmri = []\n", + " for idx, lh, rh in tqdm(zip(range(n_subjects), lhs, rhs)):\n", + " lh_graph = make_graph(*get_mesh(lh, subject, 'lh'))\n", + " rh_graph = make_graph(*get_mesh(rh, subject, 'rh'))\n", + " graph = combine_hems(lh_graph, rh_graph)\n", + " fmri.append(graph)\n", + " # fmri = jnp.array(fmri)\n", + " return fmri, imgs\n", + "\n", + "def combine_hems(lh_graph, rh_graph):\n", + " # Concatenate node feature\n", + " # number of 0s we need to add to have nodes be power of 2\n", + " n_node = lh_graph.n_node + rh_graph.n_node\n", + " padding_size = (2 ** jnp.ceil(jnp.log2(n_node)) - n_node).astype(int)[0]\n", + " padding = jnp.zeros((padding_size, 1))\n", + " nodes = jnp.concatenate([lh_graph.nodes, rh_graph.nodes, padding], axis=0)\n", + "\n", + " # Adjust senders and receivers indices for right hemisphere\n", + " rh_offset = lh_graph.n_node\n", + " rh_senders = rh_graph.senders + rh_offset\n", + " rh_receivers = rh_graph.receivers + rh_offset\n", + "\n", + " senders = jnp.concatenate([lh_graph.senders, rh_senders], axis=0)\n", + " receivers = jnp.concatenate([lh_graph.receivers, rh_receivers], axis=0)\n", + "\n", + " n_node += + padding_size\n", + " n_edge = lh_graph.n_edge + rh_graph.n_edge\n", + " \n", + " return jraph.GraphsTuple(n_node=n_node, n_edge=n_edge, edges=None, globals=None,\n", + " nodes=nodes, senders=senders, receivers=receivers)\n", + "\n", + "\n", + "@jit\n", + "def make_graph(coords, features, faces):\n", + " senders = jnp.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]], axis=0)\n", + " receivers = jnp.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]], axis=0)\n", + " n_node = jnp.array([features.shape[0]])\n", + " n_edge = jnp.array([senders.shape[0]])\n", + "\n", + " graph = jraph.GraphsTuple(n_node=n_node, n_edge=n_edge, edges=None, globals=None,\n", + " nodes=features[:, None], senders=senders, receivers=receivers)\n", + " return graph" + ] + }, + { + "cell_type": "code", + "execution_count": 111, + "metadata": {}, + "outputs": [], + "source": [ + "def update_node_fn(node_features, params):\n", + " weights, biases = params\n", + " return jnp.dot(node_features, weights) + biases\n", + "\n", + "def apply_graph_convolution(graph, params):\n", + " # Define the graph convolution layer\n", + " gcn_layer = jraph.GraphConvolution(\n", + " update_node_fn=partial(update_node_fn, params=params),\n", + " aggregate_nodes_fn=jraph.segment_sum,\n", + " add_self_edges=True,\n", + " symmetric_normalization=True,\n", + " )\n", + "\n", + " # Apply the graph convolution layer to the graph\n", + " return gcn_layer(graph)" + ] + }, + { + "cell_type": "code", + "execution_count": 112, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def pool_fn(graph, pool_size):\n", + " # Reshape and pool node features\n", + " num_nodes, num_features = graph.nodes.shape\n", + " pooled_features = graph.nodes.reshape(-1, pool_size, num_features)\n", + " pooled_features = jnp.mean(pooled_features, axis=1)\n", + "\n", + " # Update edges for the pooled graph\n", + " # Create a mapping from old node indices to new pooled node indices\n", + " node_mapping = np.repeat(np.arange(len(pooled_features)), pool_size)[:num_nodes]\n", + "\n", + " # Update senders and receivers based on the node mapping\n", + " pooled_senders = node_mapping[graph.senders]\n", + " pooled_receivers = node_mapping[graph.receivers]\n", + "\n", + " # Filter out self-loops created by pooling\n", + " edge_mask = pooled_senders != pooled_receivers\n", + " pooled_senders = pooled_senders[edge_mask]\n", + " pooled_receivers = pooled_receivers[edge_mask]\n", + "\n", + " # Update the graph with pooled nodes and edges\n", + " pooled_graph = graph._replace(nodes=pooled_features, senders=pooled_senders, receivers=pooled_receivers)\n", + " return pooled_graph\n", + "\n", + "# Example usage\n", + "# Assume 'graph' is your input jraph.GraphsTuple" ] }, { @@ -117,15 +191,15 @@ }, { "cell_type": "code", - "execution_count": 130, + "execution_count": 113, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "4\n", - "3072\n" + "2\n", + "2048\n" ] } ], @@ -133,7 +207,6 @@ "def latent_side_fn(cfg):\n", " return cfg['image_size'] // cfg['stride'] ** cfg['conv_layers']\n", "\n", - "\n", "def latent_dim_fn(cfg):\n", " # should return the size of the loatente dim depending on initial image size, stride, and number of layers, and channels\n", " channels = cfg['chan_start']\n", @@ -145,8 +218,6 @@ " latent_dim = latent_channels * latent_side ** 2\n", " return latent_dim\n", " \n", - "\n", - "\n", "latent_dim = latent_dim_fn(cfg)\n", "latent_side = latent_side_fn(cfg)\n", "print(latent_side)\n", @@ -162,7 +233,7 @@ }, { "cell_type": "code", - "execution_count": 131, + "execution_count": 114, "metadata": {}, "outputs": [], "source": [ @@ -195,17 +266,9 @@ }, { "cell_type": "code", - "execution_count": 132, + "execution_count": 115, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[-1.9884775 0.638638 1.3899286 -0.11827302]\n" - ] - } - ], + "outputs": [], "source": [ "def init_linear_layer(rng, in_dim, out_dim, tensor_dim):\n", " # tensor dim is for having fmri embedding in same array, but seperate layers.\n", @@ -219,34 +282,12 @@ " gamma, beta = init_batch_norm(b_shape)\n", " return w, b, gamma, beta\n", "\n", - "def init_linear_layers(rng, in_dim, out_dim, cfg, tensor_dim=0):\n", - " # first layer goes from in_dim to embed_dim, rest are embed to embed, and last is embed to out\n", - " rngs = jax.random.split(rng, cfg['fc_layers'])\n", - " params = []\n", - " for idx, rng in enumerate(rngs):\n", - " layer_in_dim = cfg['embed_dim'] if idx != 0 else in_dim\n", - " layer_out_dim = cfg['embed_dim'] if idx != cfg['fc_layers'] - 1 else out_dim\n", - " params.append(init_linear_layer(rng, layer_in_dim, layer_out_dim, tensor_dim))\n", - " return params\n", - "\n", "def linear(params, x):\n", " for idx, (w, b, gamma, beta) in enumerate(params):\n", " x = x @ w + b\n", " x = jax.nn.gelu(x) if idx != len(params) - 1 else x\n", " x = batch_norm(x, gamma, beta) if idx != len(params) - 1 else x\n", - " return x\n", - "\n", - "def test_linear():\n", - " cfg = syrkis.train.load_config()\n", - " cfg = cfg[dataset]\n", - " rng = jax.random.PRNGKey(0)\n", - " x = jnp.array([1.0, 2.0])\n", - " params = init_linear_layers(rng, 2, 4, cfg)\n", - " y = linear(params, x)\n", - " assert y.shape == (4,)\n", - " print(y)\n", - "\n", - "test_linear()" + " return x" ] }, { @@ -258,32 +299,15 @@ }, { "cell_type": "code", - "execution_count": 133, + "execution_count": 116, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(64, 4, 4, 192) (3, 3, 24, 48)\n" - ] - } - ], + "outputs": [], "source": [ "# Global constants for common parameters\n", "DIMENSION_NUMBERS = (\"NHWC\", \"HWIO\", \"NHWC\")\n", "\n", "\n", "@jit\n", - "def conv2d(x, w):\n", - " return jax.lax.conv_general_dilated(\n", - " x, w, \n", - " window_strides=(cfg['stride'], cfg['stride']),\n", - " padding='SAME',\n", - " dimension_numbers=DIMENSION_NUMBERS)\n", - "\n", - "\n", - "@jit\n", "def upscale_nearest_neighbor(x, scale_factor=cfg['stride']):\n", " # Assuming x has shape (batch, height, width, channels)\n", " b, h, w, c = x.shape\n", @@ -304,16 +328,45 @@ "\n", "def conv_fn(fn):\n", " def apply_fn(params, x):\n", - " for i, (w, b, gamma, beta) in enumerate(params):\n", + " for i, (w, b) in enumerate(params):\n", " x = fn(x, w, b)\n", - " x = batch_norm(x, gamma, beta) if i != len(params) - 1 else x\n", + " # x = batch_norm(x, gamma, beta) if i != len(params) - 1 else x\n", " # x = jax.nn.tanh(x) if i != len(params) - 1 else x\n", " x = jax.nn.gelu(x) if i != len(params) - 1 else x\n", " return x\n", " return apply_fn\n", "\n", "\n", - "conv = conv_fn(lambda x, w, b: conv2d(x, w) + b)\n", + "def manual_batch_graphs(graph_list):\n", + " # Initialize lists to hold the concatenated components\n", + " all_nodes = []\n", + " all_senders = []\n", + " all_receivers = []\n", + " offset = 0\n", + "\n", + " for graph in graph_list:\n", + " all_nodes.append(graph.nodes)\n", + " all_senders.append(graph.senders + offset)\n", + " all_receivers.append(graph.receivers + offset)\n", + " offset += graph.nodes.shape[0]\n", + "\n", + " # Concatenate all components\n", + " batched_nodes = jnp.concatenate(all_nodes, axis=0)\n", + " batched_senders = jnp.concatenate(all_senders, axis=0)\n", + " batched_receivers = jnp.concatenate(all_receivers, axis=0)\n", + "\n", + " # Create and return the combined GraphsTuple\n", + " return jraph.GraphsTuple(\n", + " n_node=jnp.array(batched_nodes.shape[0]),\n", + " n_edge=jnp.array(batched_senders.shape[0]),\n", + " nodes=batched_nodes,\n", + " senders=batched_senders,\n", + " receivers=batched_receivers,\n", + " edges=None, # or concatenate edges if your graph has them\n", + " globals=None # or concatenate globals if your graph has them\n", + " )\n", + "\n", + "\n", "deconv = conv_fn(lambda x, w, b: deconv2d(x, w) + b)\n", "\n", "\n", @@ -325,7 +378,7 @@ " w = syrkis.train.glorot_init(key, w_shape)\n", " b = jnp.zeros((out_chan,))\n", " gamma, beta = init_batch_norm(b.shape)\n", - " return w, b, gamma, beta\n", + " return w, b # , gamma, beta\n", "\n", "\n", "def init_conv_layers(rng, cfg, deconv=False):\n", @@ -335,19 +388,7 @@ " in_chan = cfg['in_chans'] if idx == 0 else cfg['chan_start'] * (cfg['conv_branch'] ** (idx - 1))\n", " out_chan = cfg['chan_start'] * (cfg['conv_branch'] ** idx)\n", " params.append(init_conv_params(rng, in_chan, out_chan, cfg, deconv))\n", - " return params[::-1] if deconv else params\n", - "\n", - "\n", - "def test_conv():\n", - " cfg = syrkis.train.load_config()\n", - " cfg = cfg[dataset]\n", - " x = jnp.ones((cfg['batch_size'], cfg['image_size'], cfg['image_size'], cfg['in_chans']))\n", - " conv_params = init_conv_layers(jax.random.PRNGKey(0), cfg)\n", - " deconv_params = init_conv_layers(jax.random.PRNGKey(0), cfg, deconv=True)\n", - " z = conv(conv_params, x)\n", - " print(z.shape, conv_params[1][0].shape)\n", - "\n", - "test_conv()" + " return params[::-1] if deconv else params" ] }, { @@ -359,155 +400,81 @@ }, { "cell_type": "code", - "execution_count": 134, + "execution_count": 117, "metadata": {}, "outputs": [], "source": [ - "def print_model(params):\n", - " print(f'{syrkis.train.n_params(params)} total params', end='\\n\\n')\n", - " print('\\tconv params')\n", - " for idx, (w, b, gamma, beta) in enumerate(params['conv']):\n", - " print(f'\\t\\tconv_{idx}\\t\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", - " for idx, (w, b, gamma, beta) in enumerate(params['deconv']):\n", - " print(f'\\t\\tdeconv_{idx}\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", - " print('\\n\\tlinear params')\n", - " for idx, (w, b, gamma, beta) in enumerate(params['linear_encode']):\n", - " print(f'\\t\\tencode_{idx}\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", - " for idx, (w, b, gamma, beta) in enumerate(params['linear_decode']):\n", - " print(f'\\t\\tdecode_{idx}\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')" - ] - }, - { - "cell_type": "code", - "execution_count": 135, - "metadata": {}, - "outputs": [], - "source": [ - "\n", "def dropout(x, rate, rng):\n", " rate = 1.0 - rate\n", " keep = random.bernoulli(rng, rate, x.shape)\n", " return jnp.where(keep, x / rate, 0)\n", "\n", - "def encode_img_fn(params, img, rng=None):\n", - " z = jax.nn.tanh(conv(params['conv'], img))\n", - " z = dropout(z, cfg['dropout'], rng) if rng is not None else z\n", - " z = z.reshape(cfg['batch_size'], -1)\n", - " z = jax.nn.tanh(linear(params['linear_encode'], z))\n", - " mu, logvar = jnp.split(z, 2, axis=-1) if cfg['vae'] else (z, z)\n", - " return mu, logvar\n", "\n", "def matmul_slice(A, B_slice):\n", " return jnp.dot(A, B_slice)\n", "batched_matmul = vmap(matmul_slice, in_axes=(None, 2))\n", "\n", - "def encode_fmri_fn(params, fmri, subj):\n", - " embed_cube = params['fmri_embed'][0][0] # fmri_dim x embed_dim x n_subjects\n", - " z = batched_matmul(fmri, embed_cube) # n_subjects x batch_size x embed_dim\n", - " z = z.transpose((2, 0, 1)) # embed_dim x n_subjects x batch_size # prep for broadcast\n", - " one_hot = jax.nn.one_hot(subj, len(subjects)).T # n_subjects x batch_size # prep for broadcast\n", - " z = one_hot * z # embed_dim x n_subjects x batch_size # one subject dimension has zeros after broadcast\n", - " z = z.sum(axis=1).T # batch_size x embed_dim # sum over subjects\n", - " z = jax.nn.tanh(linear(params['fmri_dense'], z))\n", - " return z\n", "\n", "@jit\n", "def decode_fn(params, z, rng=None):\n", - " z = jax.nn.tanh(linear(params['linear_decode'], z))\n", - " z = dropout(z, cfg['dropout'], rng) if rng is not None else z\n", - " z = z.reshape(cfg['batch_size'], latent_side, latent_side, -1)\n", " z = deconv(params['deconv'], z)\n", " z = jax.nn.sigmoid(z)\n", " return z\n", "\n", - "def reparametrize(mu, logvar, rng):\n", - " std = jnp.exp(0.5 * logvar)\n", - " eps = random.normal(rng, mu.shape)\n", - " return mu + eps * std\n", "\n", - "def apply_fn(params, fmri, img, subj, rng=None, variational=False):\n", - " keys = jax.random.split(rng, 4) if rng is not None else (None, None, None, None)\n", - " mu, logvar = encode_img_fn(params, img, keys[0]) if cfg['source'] == 'img' else (None, None)\n", - " z = mu if not variational and rng is None else reparametrize(mu, logvar, keys[1])\n", - " z = dropout(z, cfg['dropout'], keys[2]) if rng is not None else z\n", - " x_hat = decode_fn(params, z, keys[3])\n", - " return x_hat, (mu, logvar)\n", + "def apply_fn(params, fmri):\n", + " z = fmri\n", "\n", - "apply_non_variational = jit(partial(apply_fn, variational=False))\n", + " # Apply graph convolution layers\n", + " for i, p in enumerate(params['gcn']):\n", + " z = apply_graph_convolution(z, p)\n", + " z = pool_fn(z, 4)\n", + " z = z.nodes.flatten()\n", + "\n", + " # Apply dense layer\n", + " for i, p in enumerate(params['fcs']):\n", + " z = jnp.dot(z, p[0]) + p[1]\n", + " z = jax.nn.relu(z)\n", + "\n", + " # Apply image deconv layers to make image\n", + " z = z.reshape(1, 2, 2, -1)\n", + " z = deconv(params['cnn'], z)\n", + " z = jax.nn.sigmoid(z)\n", + " return z\n", "\n", - "def kl_divergence(mu, logvar):\n", - " sigma = jnp.exp(0.5 * logvar)\n", - " return jnp.mean(-0.5 * jnp.sum(1 + logvar - mu ** 2 - sigma ** 2, axis=-1))\n", "\n", "\n", "# This function returns the total loss and its components (recon and KL losses).\n", - "def loss_and_components(params, fmri, img, subj, rng=None, variational=False):\n", - " x_hat, (mu, sigma) = apply_fn(params, fmri, img, subj, rng, variational)\n", - " recon_loss = jnp.mean(jnp.abs((img - x_hat))) if cfg['loss_fn'] == 'l1' else jnp.mean((img - x_hat) ** 2)\n", - " kl_loss = kl_divergence(mu, sigma) if rng is not None else 0\n", - " total_loss = recon_loss + kl_loss * cfg['beta'] if variational else recon_loss\n", - " return total_loss, (recon_loss, kl_loss)\n", - "\n", - "# This function only returns the total loss, which is needed for gradient computation.\n", - "def loss_fn(params, fmri, img, subj, rng=None, variational=False):\n", - " total_loss, _ = loss_and_components(params, fmri, img, subj, rng, variational)\n", - " return total_loss\n", - "\n", - "def update_fn(params, fmri, img, subj, opt_state, rng, variational):\n", + "def loss_fn(params, fmri, img):\n", + " img_hat = apply_fn(params, fmri)\n", + " recon_loss = jnp.mean((img - img_hat) ** 2)\n", + " return recon_loss\n", + "\n", + "def update_fn(params, fmri, img, subj, opt_state):\n", " # Get the loss, aux data (recon_loss, kl_loss), and gradients\n", - " (total_loss, (recon_loss, kl_loss)), grads = value_and_grad(loss_and_components, has_aux=True)(params, fmri, img, subj, rng, variational)\n", + " loss, grads = value_and_grad(loss_fn)(params, fmri, img, subj)\n", " updates, opt_state = opt.update(grads, opt_state, params)\n", " params = optax.apply_updates(params, updates)\n", - " return params, opt_state, total_loss, recon_loss, kl_loss\n", - "\n", - "update_train = jit(partial(update_fn, variational=cfg['vae']))\n", - "\n", - "\n", - "def init_fn(rng, cfg):\n", - "\n", - " if dataset == 'mnist' or cfg['source'] == 'img':\n", - " conv_params = init_conv_layers(rng, cfg)\n", - " linear_encode_params = init_linear_layers(rng, latent_dim, cfg['embed_dim'] * 2 if cfg['vae'] else cfg['embed_dim'], cfg) # linear layers\n", - " linear_decode_params = init_linear_layers(rng, cfg['embed_dim'], latent_dim, cfg) # linear layers\n", - " deconv_params = init_conv_layers(rng, cfg, deconv=True)\n", - " params = {\"conv\": conv_params, \"deconv\": deconv_params,\n", - " \"linear_encode\": linear_encode_params, \"linear_decode\": linear_decode_params}\n", - "\n", - " if dataset != 'mnist' and cfg['source'] == 'fmri':\n", - " fmri_embed_params = init_linear_layers(rng, 19004 * len(subjects), cfg['embed_dim'], cfg, tensor_dim=len(subjects)) # linear layers\n", - " fmri_dense_params = init_linear_layers(rng, cfg['embed_dim'], cfg['embed_dim'], cfg) # linear layers\n", - " linear_decode_params = init_linear_layers(rng, cfg['embed_dim'], latent_dim, cfg) # linear layers\n", - " deconv_params = init_conv_layers(rng, cfg, deconv=True)\n", - " params = {\"fmri_embed\": fmri_embed_params, \"fmri_dense\": fmri_dense_params,\n", - " \"deconv\": deconv_params, \"linear_decode\": linear_decode_params}\n", - "\n", - " return params\n", - "\n", - "\n", - "def print_model(params):\n", - " print(f'{syrkis.train.n_params(params)} total params', end='\\n\\n')\n", - " if dataset == 'mnist' or cfg['source'] == 'img':\n", - " print('\\tconv params')\n", - " for idx, (w, b, gamma, beta) in enumerate(params['conv']):\n", - " print(f'\\t\\tconv_{idx}\\t\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", - " for idx, (w, b, gamma, beta) in enumerate(params['deconv']):\n", - " print(f'\\t\\tdeconv_{idx}\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", - " print('\\n\\tlinear params')\n", - " for idx, (w, b, gamma, beta) in enumerate(params['linear_encode']):\n", - " print(f'\\t\\tencode_{idx}\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", - " for idx, (w, b, gamma, beta) in enumerate(params['linear_decode']):\n", - " print(f'\\t\\tdecode_{idx}\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", - " if dataset != 'mnist' and cfg['source'] == 'fmri':\n", - " print('\\tfmri params')\n", - " for idx, (w, b, gamma, beta) in enumerate(params['fmri_embed']):\n", - " print(f'\\t\\tfmri_embed_{idx}\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", - " for idx, (w, b, gamma, beta) in enumerate(params['fmri_dense']):\n", - " print(f'\\t\\tfmri_dense_{idx}\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", - " for idx, (w, b, gamma, beta) in enumerate(params['linear_decode']):\n", - " print(f'\\t\\tdecode_{idx}\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", - " for idx, (w, b, gamma, beta) in enumerate(params['deconv']):\n", - " print(f'\\t\\tdeconv_{idx}\\t: {w.size + b.size}\\t|\\t{str(w.shape).replace(\", \", \" x \")[1:-1]}')\n", - "\n" + " return params, opt_state, loss\n", + "\n", + "\n", + "def init_fn(rng, cfg, scale=1e-2):\n", + " gcn = []\n", + " c_in = 1\n", + " for i in range(cfg['gcn_layers']):\n", + " rng, key = random.split(rng)\n", + " c_out = 2 ** i \n", + " gcn.append((random.normal(key, (c_in, c_out)) * scale, jnp.zeros((c_out,))))\n", + " c_in = c_out\n", + " \n", + " rng, key = random.split(rng)\n", + " cnn = init_conv_layers(key, cfg, deconv=True)\n", + " rng, key = random.split(rng)\n", + " fcs = [\n", + " (random.normal(key, (1024, cfg['latent_dim'])) * scale, jnp.zeros((cfg['latent_dim'],))),\n", + " (random.normal(key, (cfg['latent_dim'], 2048)) * scale, jnp.zeros((2048,))),\n", + " ]\n", + " return {'gcn': gcn, 'fcs': fcs, 'cnn': cnn}\n" ] }, { @@ -519,248 +486,95 @@ }, { "cell_type": "code", - "execution_count": 136, + "execution_count": 118, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "1234377 total params\n", - "\n", - "\tconv params\n", - "\t\tconv_0\t\t: 672\t|\t3 x 3 x 3 x 24\n", - "\t\tconv_1\t\t: 10416\t|\t3 x 3 x 24 x 48\n", - "\t\tconv_2\t\t: 41568\t|\t3 x 3 x 48 x 96\n", - "\t\tconv_3\t\t: 166080\t|\t3 x 3 x 96 x 192\n", - "\t\tdeconv_0\t: 165984\t|\t3 x 3 x 192 x 96\n", - "\t\tdeconv_1\t: 41520\t|\t3 x 3 x 96 x 48\n", - "\t\tdeconv_2\t: 10392\t|\t3 x 3 x 48 x 24\n", - "\t\tdeconv_3\t: 651\t|\t3 x 3 x 24 x 3\n", - "\n", - "\tlinear params\n", - "\t\tencode_0\t: 393344\t|\t3072 x 128\n", - "\t\tdecode_0\t: 396288\t|\t128 x 3072\n" - ] + "data": { + "text/plain": [ + "1963083" + ] + }, + "execution_count": 118, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ + "cfg = syrkis.train.load_config()['neuroscope']\n", "rng = jax.random.PRNGKey(0)\n", - "# params = init_fn(rng, cfg)\n", + "params = init_fn(rng, cfg)\n", "n_params = syrkis.train.n_params(params)\n", "opt_state = opt.init(params)\n", - "print_model(params)" + "syrkis.train.n_params(params)" ] }, { "cell_type": "code", - "execution_count": 137, + "execution_count": 121, "metadata": {}, "outputs": [ { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - " \n", - " \n", - "\n", - "\n", - "
\n", - " \n", - "
\n", - " \n", - "
\n", - " \"Descriptive\n", - "
\n", - " \n", - "
\n", - " \"Descriptive\n", - "
\n", - " \n", - "
\n", - " \"Descriptive\n", - "
\n", - " \n", - "
\n", - " \"Descriptive\n", - "
\n", - " \n", - "
\n", - " \"Descriptive\n", - "
\n", - " \n", - "
\n", - " \"Descriptive\n", - "
\n", - " \n", - "
\n", - " \"Descriptive\n", - "
\n", - " \n", - "
\n", - " \"Descriptive\n", - "
\n", - " \n", - "
\n", - "
\n", - "
    \n", - " \n", - "
  • \n", - " \n", - "
\n", - "
\n", - "
\n", - "\n", - "" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "100it [00:10, 9.45it/s]\n" + ] + }, + { + "ename": "ValueError", + "evalue": "All input arrays must have the same shape.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[121], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m fmri, imgs \u001b[38;5;241m=\u001b[39m make_samples(subjects[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msubj05\u001b[39m\u001b[38;5;124m'\u001b[39m], \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msubj05\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;241m100\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m fmri \u001b[38;5;241m=\u001b[39m \u001b[43mjnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfmri\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m paraply_fn \u001b[38;5;241m=\u001b[39m jit(vmap(apply_fn, in_axes\u001b[38;5;241m=\u001b[39m(\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;241m0\u001b[39m)))\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m100\u001b[39m, \u001b[38;5;241m2\u001b[39m)):\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2037\u001b[0m, in \u001b[0;36marray\u001b[0;34m(object, dtype, copy, order, ndmin)\u001b[0m\n\u001b[1;32m 2035\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mobject\u001b[39m, (\u001b[38;5;28mlist\u001b[39m, \u001b[38;5;28mtuple\u001b[39m)):\n\u001b[1;32m 2036\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mobject\u001b[39m:\n\u001b[0;32m-> 2037\u001b[0m out \u001b[38;5;241m=\u001b[39m stack(\u001b[43m[\u001b[49m\u001b[43masarray\u001b[49m\u001b[43m(\u001b[49m\u001b[43melt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43melt\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mobject\u001b[39;49m\u001b[43m]\u001b[49m)\n\u001b[1;32m 2038\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2039\u001b[0m out \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray([], dtype\u001b[38;5;241m=\u001b[39mdtype)\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2037\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 2035\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mobject\u001b[39m, (\u001b[38;5;28mlist\u001b[39m, \u001b[38;5;28mtuple\u001b[39m)):\n\u001b[1;32m 2036\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mobject\u001b[39m:\n\u001b[0;32m-> 2037\u001b[0m out \u001b[38;5;241m=\u001b[39m stack([\u001b[43masarray\u001b[49m\u001b[43m(\u001b[49m\u001b[43melt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m elt \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mobject\u001b[39m])\n\u001b[1;32m 2038\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2039\u001b[0m out \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray([], dtype\u001b[38;5;241m=\u001b[39mdtype)\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2070\u001b[0m, in \u001b[0;36masarray\u001b[0;34m(a, dtype, order)\u001b[0m\n\u001b[1;32m 2068\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 2069\u001b[0m dtype \u001b[38;5;241m=\u001b[39m dtypes\u001b[38;5;241m.\u001b[39mcanonicalize_dtype(dtype, allow_opaque_dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m-> 2070\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43morder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morder\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2037\u001b[0m, in \u001b[0;36marray\u001b[0;34m(object, dtype, copy, order, ndmin)\u001b[0m\n\u001b[1;32m 2035\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mobject\u001b[39m, (\u001b[38;5;28mlist\u001b[39m, \u001b[38;5;28mtuple\u001b[39m)):\n\u001b[1;32m 2036\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mobject\u001b[39m:\n\u001b[0;32m-> 2037\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mstack\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43masarray\u001b[49m\u001b[43m(\u001b[49m\u001b[43melt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43melt\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mobject\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2038\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2039\u001b[0m out \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray([], dtype\u001b[38;5;241m=\u001b[39mdtype)\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1778\u001b[0m, in \u001b[0;36mstack\u001b[0;34m(arrays, axis, out, dtype)\u001b[0m\n\u001b[1;32m 1776\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m a \u001b[38;5;129;01min\u001b[39;00m arrays:\n\u001b[1;32m 1777\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m shape(a) \u001b[38;5;241m!=\u001b[39m shape0:\n\u001b[0;32m-> 1778\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAll input arrays must have the same shape.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1779\u001b[0m new_arrays\u001b[38;5;241m.\u001b[39mappend(expand_dims(a, axis))\n\u001b[1;32m 1780\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m concatenate(new_arrays, axis\u001b[38;5;241m=\u001b[39maxis, dtype\u001b[38;5;241m=\u001b[39mdtype)\n", + "\u001b[0;31mValueError\u001b[0m: All input arrays must have the same shape." + ] } ], "source": [ - "def evaluate(params, train_batches, eval_batches, eval_steps=5):\n", - " train_loss, eval_loss = 0, 0\n", - " for i in range(eval_steps):\n", - " _, _, train_img, _ = next(train_batches)\n", - " _, _, eval_img, _ = next(eval_batches)\n", - " train_loss += loss_fn(params, None, train_img, None) / eval_steps\n", - " eval_loss += loss_fn(params, None, eval_img, None) / eval_steps\n", - " return train_loss, eval_loss\n", - "\n", - "def static_info(cfg):\n", - " return [\n", - " f\"n_params : {n_params}\",\n", - " f\"dropout : {cfg['dropout']}\",\n", - " f\"lr : {cfg['lr']}\",\n", - " f\"loss : {cfg['loss_fn']}\",\n", - " f\"embed_dim : {cfg['embed_dim']}\",\n", - " f\"beta : {cfg['beta']}\",\n", - " f\"batch_size : {cfg['batch_size']}\",\n", - " f\"conv_layers : {cfg['conv_layers']}\",\n", - " ]\n", - "\n", - "def dynamic_info(total_loss, recon_loss, kl_loss, step, cfg):\n", - " return [\n", - " #f\"loss : {total_loss:.4f}\",\n", - " f\"recon : {recon_loss:.4f}\",\n", - " f\"kl : {kl_loss:.4f}\",\n", - " f\"step : {step + 1} / {cfg['epochs'] * n_samples // cfg['batch_size']}\",\n", - " ]\n", - " \n", - "def imgs_fn(params, train_eval_img, img_sample, z_seed):\n", - " \n", - " return [\n", - " apply_non_variational(params, train_eval_lh, train_eval_img, train_eval_subj)[0][:6], # recon train\n", - " apply_non_variational(params, None, img_sample, None)[0][:6], # recon eval\n", - " decode_fn(params, z_seed).reshape(-1, cfg['image_size'], cfg['image_size'], cfg['in_chans'])[:6], # recon latent\n", - " ]\n", - "\n", - "train_eval_lh, train_eval_rh, train_eval_img, train_eval_subj = next(train_batches)\n", - "z_seed = jax.random.normal(rng, (encode_img_fn(params, train_eval_img)[0].shape))\n", - "\n", - "def train(params, opt_state, cfg, train_batches, eval_batches, rng, eval_step=100, monitor=False):\n", - " if monitor:\n", - " wandb.init(project='neuroscope', entity=\"syrkis\", config=cfg)\n", - "\n", - " for step in range(cfg['epochs'] * n_samples // cfg['batch_size']):\n", - " _, _, img, _ = next(train_batches)\n", - "\n", - " rng, key = jax.random.split(rng)\n", - " params, opt_state, total_loss, recon_loss, kl_loss = update_train(params, _, img, _, opt_state, key )\n", - " imgs = imgs_fn(params, train_eval_img, img_sample, z_seed)\n", - " info = dynamic_info(total_loss, recon_loss, kl_loss, step, cfg)\n", - " syrkis.plot.multiples(jnp.concatenate(imgs, axis=0), figsize=(2, 4), info={'top': info})\n", - " \"\"\" if monitor and step % (eval_step * 10) == 0:\n", - " train_loss, eval_loss = evaluate(params, train_batches, eval_batches)\n", - " wandb.log({\"train_loss\": train_loss, \"eval_loss\": eval_loss}) \"\"\"\n", - " if monitor:\n", - " wandb.finish()\n", - "\n", - "rng = jax.random.PRNGKey(0) # passing to train will make make it variational\n", - "train(params, opt_state, cfg, train_batches, eval_batches, rng, monitor=False)" + "fmri, imgs = make_samples(subjects['subj05'], 'subj05', 100)\n", + "fmri = jnp.array(fmri)\n", + "paraply_fn = jit(vmap(apply_fn, in_axes=(None, 0)))\n", + "for i in tqdm(range(0, 100, 2)):\n", + " x, y = fmri[i:i+2], imgs[i:i+2]\n", + " y_hat = paraply_fn(params, x)\n", + " loss = loss_fn(params, x, y)\n", + " " ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, + "execution_count": 120, "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "ename": "XlaRuntimeError", + "evalue": "UNKNOWN: /var/folders/g5/r0c49hqx3f95cpg6vnztk5s80000gn/T/ipykernel_88449/498112577.py:1:16: error: failed to legalize operation 'mhlo.pad'\n/var/folders/g5/r0c49hqx3f95cpg6vnztk5s80000gn/T/ipykernel_88449/498112577.py:1:16: note: called from\n/var/folders/g5/r0c49hqx3f95cpg6vnztk5s80000gn/T/ipykernel_88449/498112577.py:1:16: note: see current operation: %89 = \"mhlo.pad\"(%88, %1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<1> : tensor<1xi64>} : (tensor<2xsi32>, tensor) -> tensor<3xsi32>\n", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mXlaRuntimeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[120], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m batched_graph \u001b[38;5;241m=\u001b[39m \u001b[43mjraph\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfmri\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/site-packages/jraph/_src/utils.py:477\u001b[0m, in \u001b[0;36mbatch\u001b[0;34m(graphs)\u001b[0m\n\u001b[1;32m 424\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbatch\u001b[39m(graphs: Sequence[gn_graph\u001b[38;5;241m.\u001b[39mGraphsTuple]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m gn_graph\u001b[38;5;241m.\u001b[39mGraphsTuple:\n\u001b[1;32m 425\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Returns a batched graph given a list of graphs.\u001b[39;00m\n\u001b[1;32m 426\u001b[0m \n\u001b[1;32m 427\u001b[0m \u001b[38;5;124;03m This method will concatenate the ``nodes``, ``edges`` and ``globals``,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 475\u001b[0m \u001b[38;5;124;03m graph.\u001b[39;00m\n\u001b[1;32m 476\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 477\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgraphs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnp_\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mjnp\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/site-packages/jraph/_src/utils.py:489\u001b[0m, in \u001b[0;36m_batch\u001b[0;34m(graphs, np_)\u001b[0m\n\u001b[1;32m 486\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Returns batched graph given a list of graphs and a numpy-like module.\"\"\"\u001b[39;00m\n\u001b[1;32m 487\u001b[0m \u001b[38;5;66;03m# Calculates offsets for sender and receiver arrays, caused by concatenating\u001b[39;00m\n\u001b[1;32m 488\u001b[0m \u001b[38;5;66;03m# the nodes arrays.\u001b[39;00m\n\u001b[0;32m--> 489\u001b[0m offsets \u001b[38;5;241m=\u001b[39m \u001b[43mnp_\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcumsum\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 490\u001b[0m \u001b[43m \u001b[49m\u001b[43mnp_\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[43mnp_\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msum\u001b[49m\u001b[43m(\u001b[49m\u001b[43mg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_node\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mg\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mgraphs\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 492\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_map_concat\u001b[39m(nests):\n\u001b[1;32m 493\u001b[0m concat \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m \u001b[38;5;241m*\u001b[39margs: np_\u001b[38;5;241m.\u001b[39mconcatenate(args)\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/site-packages/jax/_src/numpy/reductions.py:651\u001b[0m, in \u001b[0;36m_make_cumulative_reduction..cumulative_reduction\u001b[0;34m(a, axis, dtype, out)\u001b[0m\n\u001b[1;32m 648\u001b[0m \u001b[38;5;129m@_wraps\u001b[39m(np_reduction, skip_params\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mout\u001b[39m\u001b[38;5;124m'\u001b[39m])\n\u001b[1;32m 649\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcumulative_reduction\u001b[39m(a: ArrayLike, axis: Axis \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 650\u001b[0m dtype: DTypeLike \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m, out: \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Array:\n\u001b[0;32m--> 651\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_cumulative_reduction\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_ensure_optional_axes\u001b[49m\u001b[43m(\u001b[49m\u001b[43maxis\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout\u001b[49m\u001b[43m)\u001b[49m\n", + " \u001b[0;31m[... skipping hidden 14 frame]\u001b[0m\n", + "File \u001b[0;32m~/.pyenv/versions/3.11.4/lib/python3.11/site-packages/jax/_src/dispatch.py:465\u001b[0m, in \u001b[0;36mbackend_compile\u001b[0;34m(backend, module, options, host_callbacks)\u001b[0m\n\u001b[1;32m 460\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m backend\u001b[38;5;241m.\u001b[39mcompile(built_c, compile_options\u001b[38;5;241m=\u001b[39moptions,\n\u001b[1;32m 461\u001b[0m host_callbacks\u001b[38;5;241m=\u001b[39mhost_callbacks)\n\u001b[1;32m 462\u001b[0m \u001b[38;5;66;03m# Some backends don't have `host_callbacks` option yet\u001b[39;00m\n\u001b[1;32m 463\u001b[0m \u001b[38;5;66;03m# TODO(sharadmv): remove this fallback when all backends allow `compile`\u001b[39;00m\n\u001b[1;32m 464\u001b[0m \u001b[38;5;66;03m# to take in `host_callbacks`\u001b[39;00m\n\u001b[0;32m--> 465\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mbackend\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbuilt_c\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcompile_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mXlaRuntimeError\u001b[0m: UNKNOWN: /var/folders/g5/r0c49hqx3f95cpg6vnztk5s80000gn/T/ipykernel_88449/498112577.py:1:16: error: failed to legalize operation 'mhlo.pad'\n/var/folders/g5/r0c49hqx3f95cpg6vnztk5s80000gn/T/ipykernel_88449/498112577.py:1:16: note: called from\n/var/folders/g5/r0c49hqx3f95cpg6vnztk5s80000gn/T/ipykernel_88449/498112577.py:1:16: note: see current operation: %89 = \"mhlo.pad\"(%88, %1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<1> : tensor<1xi64>} : (tensor<2xsi32>, tensor) -> tensor<3xsi32>\n" + ] + } + ], + "source": [ + "\n", + "batched_graph = jraph.batch(fmri)" + ] }, { "cell_type": "code", diff --git a/src/data.py b/src/data.py index f362c82..dff84d2 100644 --- a/src/data.py +++ b/src/data.py @@ -29,7 +29,7 @@ def load_split(path: str, split_idx: int, image_size: int, subject: str, precisi lh_fmri = np.load(os.path.join(path, 'training_fmri', 'lh_training_fmri.npy'))[split_idx] rh_fmri = np.load(os.path.join(path, 'training_fmri', 'rh_training_fmri.npy'))[split_idx] images, metadata = load_coco(path, split_idx, image_size, subject) - return lh_fmri.astype(precision), rh_fmri.astype(precision), images.astype(precision), metadata + return jnp.array(lh_fmri), jnp.array(rh_fmri), jnp.array(images) # , metadata def load_metadata(image_files: list, subject: str) -> list: diff --git a/src/fmri.py b/src/fmri.py index 7b99f9b..46981d0 100644 --- a/src/fmri.py +++ b/src/fmri.py @@ -43,6 +43,7 @@ def fsaverage_vec(challenge_vec, subject, roi, hem) -> np.ndarray: fsaverage_response[np.where(fsaverage_space)[0]] = challenge_vec return fsaverage_response +import numpy as np def get_bold_with_coords_and_faces(challenge_vec, subject, hem, roi=None): """ @@ -57,9 +58,20 @@ def get_bold_with_coords_and_faces(challenge_vec, subject, hem, roi=None): # Load the coordinates and faces for the selected hemisphere coords, faces = load_surf_mesh(ATLAS[side]) - # Filter the coordinates based on the fsaverage_response - # Only include coordinates where there is a non-zero response - filtered_coords = coords[np.where(fsaverage_response)[0]] + # Create a mask for non-zero fsaverage_response + response_mask = np.where(fsaverage_response)[0] + + # Filter the coordinates + filtered_coords = coords[response_mask] + + # Create a mapping from old vertex indices to new ones + index_mapping = np.full(np.max(faces) + 1, -1) # Initialize with -1 + index_mapping[response_mask] = np.arange(response_mask.size) + + # Adjust faces to new indexing and filter out invalid faces + filtered_faces = index_mapping[faces] + valid_faces_mask = np.all(filtered_faces != -1, axis=1) + filtered_faces = filtered_faces[valid_faces_mask] - # Return the filtered coordinates, corresponding BOLD signal values, and faces - return filtered_coords, fsaverage_response[np.where(fsaverage_response)[0]], faces + # Return the filtered coordinates, corresponding BOLD signal values, and adjusted faces + return filtered_coords, fsaverage_response[response_mask], filtered_faces diff --git a/src/graph.py b/src/graph.py new file mode 100644 index 0000000..5c8c15c --- /dev/null +++ b/src/graph.py @@ -0,0 +1,26 @@ +# graph.py +# neuroscope graph structure +# by: Noah Syrkis + +#imports +import jax.numpy as jnp +import jraph + +nodes_features = jnp.array([[0.], [1.], [2.]]) +senders = jnp.array([0, 1, 2]) +receivers = jnp.array([1, 2, 0]) +edges = jnp.array([[0.], [1.], [2.]]) + +n_node = jnp.array([len(nodes_features)]) +n_edge = jnp.array([len(edges)]) + +global_context = jnp.array([[0.]]) +graph = jraph.GraphsTuple(nodes=nodes_features, edges=edges, + senders=senders, receivers=receivers, + n_node=n_node, n_edge=n_edge, + globals=global_context) + +graphs = jraph.batch([graph, graph]) + +node_targets = jnp.array([[True], [False], [True]]) +graph = graph._replace(nodes={'inputs': graph.nodes, 'targets': node_targets}) \ No newline at end of file