diff --git a/.gitignore b/.gitignore index 509187e..791ddd9 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ training_history.pkl fourier_power_only.pdf notebooks/figs/ notebooks/figs_temp/ +notebooks/figures/ notebooks/saved_models/ notebooks/checkpoint* notebooks/W-weights.pdf diff --git a/notebooks/2D.ipynb b/notebooks/2D.ipynb deleted file mode 100644 index eac8f90..0000000 --- a/notebooks/2D.ipynb +++ /dev/null @@ -1,1701 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 2D Sequential modular addition with a Quadratic RNN\n", - "\n", - "To do:\n", - "- Should have an option to compute loss over all intermediate compositions (this will allow us to train for longer)\n", - "- Should create a dataloader that samples online\n", - "- Better activation plots\n" - ], - "id": "14ad894a-55e4-4f86-9e22-b5260504966b" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Set up" - ], - "id": "6d1d8a65" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "# autoreload\n", - "%load_ext autoreload\n", - "%autoreload 2\n", - "# jupyter black formatter\n", - "%load_ext jupyter_black\n", - "\n", - "import subprocess\n", - "import os\n", - "import sys\n", - "\n", - "gitroot_path = subprocess.check_output(\n", - " [\"git\", \"rev-parse\", \"--show-toplevel\"], universal_newlines=True\n", - ")\n", - "\n", - "os.chdir(os.path.join(gitroot_path[:-1], \"gagf\"))\n", - "print(\"Working directory: \", os.getcwd())\n", - "\n", - "sys_dir = os.path.dirname(os.getcwd())\n", - "sys.path.append(sys_dir)\n", - "print(\"Directory added to path: \", sys_dir)\n", - "sys.path.append(os.getcwd())\n", - "print(\"Directory added to path: \", os.getcwd())" - ], - "execution_count": null, - "outputs": [], - "id": "fd71a9a8" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Imports" - ], - "id": "62884979" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "# Core\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn as nn\n", - "\n", - "# Torch utilities\n", - "import torch.optim as optim\n", - "from torch.utils.data import DataLoader, TensorDataset\n", - "\n", - "# Vision\n", - "import torchvision\n", - "from torchvision import transforms\n", - "\n", - "# Plotting\n", - "import matplotlib.pyplot as plt\n", - "import matplotlib.gridspec as gridspec\n", - "from matplotlib.patches import Rectangle, Patch\n", - "from matplotlib.ticker import MaxNLocator\n", - "from matplotlib.lines import Line2D\n", - "from matplotlib.colors import Normalize\n", - "from matplotlib.colors import LogNorm\n", - "import matplotlib.cm as cm\n", - "\n", - "# Misc\n", - "from tqdm import tqdm\n", - "from typing import Optional" - ], - "execution_count": null, - "outputs": [], - "id": "80f249f1-6985-4c73-86cd-04e1adac3e8d" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## RNN Architecture" - ], - "id": "7a0ecbbd-ceaf-4bef-af4a-13a22fa70063" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "from gagf.rnns.model import QuadraticRNN" - ], - "execution_count": null, - "outputs": [], - "id": "5f63c4dd" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Optimization" - ], - "id": "f7e7336b-5c6e-48af-a357-2b2c877f6168" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "from gagf.rnns.train import train" - ], - "execution_count": null, - "outputs": [], - "id": "1035f81c-e877-4655-8640-4e4c3d323af8" - }, - { - "cell_type": "markdown", - "metadata": 
{}, - "source": [ - "## Plotting functions" - ], - "id": "0e86c4f6-83a6-4465-abf0-7d104432cc9c" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "from gagf.rnns.utils import (\n", - " style_axes,\n", - " get_power_2d_adele,\n", - " topk_template_freqs,\n", - ")" - ], - "execution_count": null, - "outputs": [], - "id": "014e2d10-9550-4fd4-adb7-168a27fda1b3" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2D Dataset" - ], - "id": "54bef8ae" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "from gagf.rnns.datamodule import (\n", - " build_modular_addition_sequence_dataset_2d,\n", - " mnist_template_2d,\n", - " generate_template_unique_freqs,\n", - ")" - ], - "execution_count": null, - "outputs": [], - "id": "2aa4c3fd" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Experiment Setup" - ], - "id": "f1c0453e" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "from gagf.rnns.utils import plot_2d_power_spectrum, plot_2d_signal, get_power_2d\n", - "\n", - "# Set seed\n", - "np.random.seed(5)\n", - "torch.manual_seed(5)\n", - "\n", - "# 2D dimensions\n", - "p1, p2 = 28, 28 # Can be different, but start with square\n", - "p_flat = p1 * p2\n", - "\n", - "# Generate 2D template\n", - "template_2d = generate_template_unique_freqs(p1, p2, n_freqs=10)\n", - "\n", - "# Mean center template\n", - "template_2d = template_2d - np.mean(template_2d)\n", - "\n", - "# Visualize template and its spectrum\n", - "fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n", - "plot_2d_signal(axes[0], template_2d, title=\"2D Template\", cmap=\"grey\")\n", - "power_2d, fx, fy = get_power_2d(template_2d)\n", - "plot_2d_power_spectrum(axes[1], power_2d, fx, fy, title=\"Template Power Spectrum\")\n", - "plt.tight_layout()\n", - "plt.show()" - ], - "execution_count": null, - "outputs": [], - "id": "d424edf7-ad00-4836-a57c-6de2b2e06a63" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "# Build sequence data\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "k = 3 # sequence length\n", - "mode = \"sampled\"\n", - "# TEST_MODE: Reduce num_samples for faster automated testing\n", - "import os\n", - "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n", - "num_samples = 1000 if TEST_MODE else 100000 # Reduced in test mode\n", - "\n", - "X_seq_2d, Y_seq_2d, sequence_xy = build_modular_addition_sequence_dataset_2d(\n", - " p1, p2, template_2d, k, mode=mode, num_samples=num_samples\n", - ")\n", - "\n", - "# Convert to torch tensors\n", - "X_seq_2d_t = torch.tensor(X_seq_2d, dtype=torch.float32, device=device)\n", - "Y_seq_2d_t = torch.tensor(Y_seq_2d, dtype=torch.float32, device=device)\n", - "\n", - "print(f\"Dataset shapes:\")\n", - "print(f\" X: {X_seq_2d_t.shape} (N, k, p1*p2)\")\n", - "print(f\" Y: {Y_seq_2d_t.shape} (N, p1*p2)\")\n", - "print(f\" Flattened dimension: {p_flat}\")" - ], - "execution_count": null, - "outputs": [], - "id": "1a643405" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Train 2D model" - ], - "id": "624d77d9" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "from gagf.rnns.main import main" - ], - "execution_count": null, - "outputs": [], - "id": "4b53c4ac" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "config = {\n", - " \"data\": {\n", - " \"p1\": p1,\n", - " \"p2\": p2,\n", - " \"k\": k,\n", - " \"mode\": mode,\n", - " \"num_samples\": num_samples,\n", - " \"n_freqs\": 10,\n", - " 
\"batch_size\": 1000,\n", - " \"seed\": 5,\n", - " },\n", - " \"model\": {\n", - " \"hidden_dim\": 36,\n", - " \"init_scale\": 1.0e-5,\n", - " },\n", - " \"training\": {\n", - " \"epochs\": 5,\n", - " \"learning_rate\": 0.0001,\n", - " \"weight_decay\": 0.0,\n", - " \"betas\": [0.9, 0.999],\n", - " \"grad_clip\": 0.1,\n", - " \"verbose_interval\": 5,\n", - " },\n", - " \"device\": \"cuda\",\n", - "}" - ], - "execution_count": null, - "outputs": [], - "id": "474b1423" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "main(config)" - ], - "execution_count": null, - "outputs": [], - "id": "38b78194" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "from gagf.rnns.train import train\n", - "\n", - "# TEST_MODE: Reduce batch_size and hidden_dim for faster automated testing\n", - "import os\n", - "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n", - "batch_size = 100 if TEST_MODE else 1000 # Reduced in test mode\n", - "seq_dataset_2d = TensorDataset(X_seq_2d_t, Y_seq_2d_t)\n", - "seq_loader_2d = DataLoader(\n", - " seq_dataset_2d, batch_size=batch_size, shuffle=True, drop_last=False\n", - ")\n", - "\n", - "# Model - note p is now p1*p2\n", - "template_torch = torch.tensor(template_2d, device=device, dtype=torch.float32).flatten()\n", - "hidden_dim_2d = 12 if TEST_MODE else 36 # Reduced in test mode\n", - "rnn_2d = QuadraticRNN(\n", - " p=p_flat,\n", - " d=hidden_dim_2d,\n", - " template=template_torch,\n", - " init_scale=1e-5,\n", - ").to(device)\n", - "\n", - "criterion = nn.MSELoss()\n", - "\n", - "weight_decay = 0 # relevant for W_mix plot\n", - "learning_rate = 0.0001\n", - "\n", - "optimizer = optim.Adam(\n", - " rnn_2d.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=weight_decay\n", - ")\n", - "\n", - "# TEST_MODE: Set to reduce epochs for automated testing\n", - "import os\n", - "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n", - "epochs = 2 if TEST_MODE else 5\n", - "\n", - "loss_hist_2d, acc_hist_2d, param_hist_2d = train(\n", - " rnn_2d,\n", - " seq_loader_2d,\n", - " criterion,\n", - " optimizer,\n", - " epochs=epochs,\n", - " verbose_interval=max(1, epochs),\n", - " grad_clip=0.1,\n", - ")" - ], - "execution_count": null, - "outputs": [], - "id": "31005b99" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Loss Plot" - ], - "id": "aaf7af30-5f5e-4ea5-8425-1da1efc1fa69" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "# --- Plot RNN training loss only ---\n", - "fig, ax = plt.subplots(1, 1, figsize=(10, 6))\n", - "\n", - "# x-axis = epochs 1..T\n", - "x = np.arange(1, len(loss_hist_2d) + 1)\n", - "ax.plot(x, loss_hist_2d, lw=4)\n", - "\n", - "# === Compute power spectrum of template ===\n", - "x_freq, y_freq, power = get_power_2d_adele(template_2d)\n", - "power = power.flatten()\n", - "valid = power > 1e-20\n", - "power = power[valid]\n", - "sorted_idx = np.argsort(power)[::-1] # np.argsort with [::-1] gives descending order\n", - "power = power[sorted_idx]\n", - "# print(\"Power in x: {}\".format(power))\n", - "\n", - "# Plot theoretical lines\n", - "alpha_values = [np.sum(power[k:]) for k in range(len(power))]\n", - "coef = 1 / (p1 * p2)\n", - "for k, alpha in enumerate(alpha_values):\n", - " ax.axhline(y=coef * alpha, color=\"black\", linestyle=\"--\", linewidth=2, zorder=-2)\n", - "\n", - "\n", - "# # Steps\n", - "# steps = [1, 20, 50, len(param_hist) - 1] # make sure these are < len(rnn_param_hist)\n", - "# for step in steps:\n", - "# 
ax.axvline(step, c=\"k\")\n", - "\n", - "# plot in log scale\n", - "# ax.set_xscale(\"log\")\n", - "# ax.set_yscale(\"log\")\n", - "# ax.set_ylim(3.5e-2, 1.3e-1)\n", - "\n", - "ax.set_xlabel(\"Epochs\", fontsize=24)\n", - "ax.set_ylabel(\"Train Loss\", fontsize=24)\n", - "\n", - "style_axes(ax)\n", - "\n", - "plt.grid(False)\n", - "plt.tight_layout()\n", - "# plt.savefig(\"rnn-loss.pdf\", bbox_inches=\"tight\")\n", - "plt.show()" - ], - "execution_count": null, - "outputs": [], - "id": "d1037a02" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Prediction Plot" - ], - "id": "d4f01b69" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "import torch\n", - "\n", - "# --- config ---\n", - "steps_2d = [1, 5] # , 10, len(param_hist_2d) - 1]\n", - "example_idx = int(np.random.randint(len(Y_seq_2d_t)))\n", - "cmap = \"gray\"\n", - "origin = \"upper\" # use \"lower\" if you prefer cartesian orientation\n", - "\n", - "device = next(rnn_2d.parameters()).device\n", - "rnn_2d.to(device).eval()\n", - "\n", - "# --- collect preds for the SAME example across steps ---\n", - "preds = []\n", - "with torch.no_grad():\n", - " truth_2d = Y_seq_2d_t[example_idx].reshape(p1, p2).cpu().numpy()\n", - "\n", - "for step in steps_2d:\n", - " rnn_2d.load_state_dict(param_hist_2d[step], strict=True)\n", - " with torch.no_grad():\n", - " x = X_seq_2d_t[example_idx : example_idx + 1].to(device) # (1, k, p1*p2)\n", - " pred_2d = rnn_2d(x).reshape(p1, p2).detach().cpu().numpy()\n", - " preds.append(pred_2d)\n", - "\n", - "# --- shared color scale ---\n", - "vmin = np.min(truth_2d) # min(np.min(truth_2d), *(np.min(p) for p in preds))\n", - "vmax = np.max(truth_2d) # max(np.max(truth_2d), *(np.max(p) for p in preds))\n", - "\n", - "# --- plot: rows = [Prediction, Target], cols = time steps ---\n", - "fig, axes = plt.subplots(\n", - " 2, len(steps_2d), figsize=(3.5 * len(steps_2d), 6), layout=\"constrained\"\n", - ")\n", - "\n", - "for col, (step, pred_2d) in enumerate(zip(steps_2d, preds)):\n", - " im = axes[0, col].imshow(pred_2d, vmin=vmin, vmax=vmax, cmap=cmap, origin=origin)\n", - " axes[0, col].set_title(f\"t = {step}\", fontsize=12)\n", - " axes[0, col].set_xticks([])\n", - " axes[0, col].set_yticks([])\n", - "\n", - " axes[1, col].imshow(truth_2d, vmin=vmin, vmax=vmax, cmap=cmap, origin=origin)\n", - " axes[1, col].set_xticks([])\n", - " axes[1, col].set_yticks([])\n", - "\n", - "axes[0, 0].set_ylabel(\"Prediction\", fontsize=12)\n", - "axes[1, 0].set_ylabel(\"Target\", fontsize=12)\n", - "\n", - "# single shared colorbar on the right\n", - "fig.colorbar(im, ax=axes, location=\"right\", shrink=0.9, pad=0.02).set_label(\n", - " \"value\", fontsize=12\n", - ")\n", - "\n", - "plt.show()" - ], - "execution_count": null, - "outputs": [], - "id": "11e29014" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Prediction Power Spectrum Plot" - ], - "id": "b744c04c" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "## 2D Power Spectrum Over Time (top-K template freqs)\n", - "\n", - "# --- how many frequencies to track (includes DC if it ranks in top-K) ---\n", - "num_freqs_to_track = 10 # change as you like\n", - "\n", - "# --- pick top-K frequencies from the template's 2D power (non-redundant half-plane) ---\n", - "template_power_2d, _, _ = get_power_2d(template_2d) # shape (p1, p2)\n", - "tracked_freqs = topk_template_freqs(template_2d, K=num_freqs_to_track)\n", - "target_powers = 
{(kx, ky): template_power_2d[kx, ky] for (kx, ky) in tracked_freqs}\n", - "\n", - "# --- choose analysis steps (log-spaced) ---\n", - "num_points = 100\n", - "num_samples = 100\n", - "T = len(param_hist_2d)\n", - "steps_2d_analysis = np.unique(\n", - " np.logspace(0, np.log10(max(T - 1, 1)), num_points, dtype=int)\n", - ")\n", - "steps_2d_analysis = steps_2d_analysis[steps_2d_analysis < T]\n", - "steps_2d_analysis = np.insert(steps_2d_analysis, 0, 0)\n", - "\n", - "# --- track average output power at those frequencies over training ---\n", - "powers_over_time_2d = {freq: [] for freq in tracked_freqs}\n", - "\n", - "with torch.no_grad():\n", - " for step in tqdm(steps_2d_analysis, desc=\"Computing power spectra\"):\n", - " rnn_2d.load_state_dict(param_hist_2d[step], strict=True)\n", - " rnn_2d.eval()\n", - "\n", - " outputs_flat = (\n", - " rnn_2d(X_seq_2d_t[:num_samples].to(device)).detach().cpu().numpy()\n", - " ) # (N, p1*p2)\n", - "\n", - " # average power over batch\n", - " # (simple loop for clarity; vectorize later if needed)\n", - " powers_batch = []\n", - " for i in range(outputs_flat.shape[0]):\n", - " out_2d = outputs_flat[i].reshape(p1, p2)\n", - " power_i, _, _ = get_power_2d(out_2d) # (p1, p2)\n", - " powers_batch.append(power_i)\n", - " avg_power = np.mean(powers_batch, axis=0) # (p1, p2)\n", - "\n", - " for kx, ky in tracked_freqs:\n", - " powers_over_time_2d[(kx, ky)].append(avg_power[kx, ky])\n", - "\n", - "# --- convert lists to arrays ---\n", - "for freq in tracked_freqs:\n", - " powers_over_time_2d[freq] = np.array(powers_over_time_2d[freq])" - ], - "execution_count": null, - "outputs": [], - "id": "43a6b047" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Plot prediction power spectrum over time" - ], - "id": "cbfd47eb" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "from matplotlib.ticker import FormatStrFormatter\n", - "\n", - "# Colors used for both the bands (top) and the mode curves (bottom)\n", - "colors_2d = plt.cm.tab10(np.linspace(0, 1, len(tracked_freqs)))\n", - "\n", - "# Make the figure a bit shorter and reduce vertical gap between plots\n", - "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 10), sharex=True)\n", - "fig.subplots_adjust(left=0.12, right=0.98, top=0.96, bottom=0.10, hspace=0.12)\n", - "\n", - "# ---------------- Top: training loss + theory lines ----------------\n", - "x = np.arange(1, len(loss_hist_2d) + 1)\n", - "ax1.plot(x, loss_hist_2d, lw=4, zorder=5)\n", - "\n", - "# Compute power spectrum of template for theory lines\n", - "fy, fx, power = get_power_2d_adele(template_2d) # adapt unpacking if your func differs\n", - "power_flat = np.sort(power.flatten()[power.flatten() > 1e-20])[::-1]\n", - "\n", - "# Theory levels (cumulative tail sums)\n", - "alpha_values = np.array([np.sum(power_flat[k:]) for k in range(len(power_flat))])\n", - "coef = 1.0 / (p1 * p2)\n", - "y_levels = coef * alpha_values # strictly decreasing\n", - "\n", - "# Shade horizontal bands between successive theory lines\n", - "n_bands = min(len(tracked_freqs), len(y_levels) - 1)\n", - "for i in range(n_bands):\n", - " y_top = y_levels[i]\n", - " y_bot = y_levels[i + 1]\n", - " ax1.axhspan(y_bot, y_top, facecolor=colors_2d[i], alpha=0.15, zorder=-3)\n", - "\n", - "# Draw the black theory lines\n", - "for y in y_levels[: n_bands + 1]:\n", - " ax1.axhline(y=y, color=\"black\", linestyle=\"--\", linewidth=2, zorder=-2)\n", - "\n", - "# ax1.set_ylim(3.5e-2, 1.3e-1)\n", - "ax1.set_ylabel(\"Train Loss\", fontsize=24)\n", - 
"style_axes(ax1)\n", - "ax1.grid(False)\n", - "ax1.tick_params(labelbottom=False) # only show x ticks on bottom plot\n", - "\n", - "# ---------------- Bottom: tracked mode power over time ----------------\n", - "for i, (kx, ky) in enumerate(tracked_freqs):\n", - " ax2.plot(steps_2d_analysis, powers_over_time_2d[(kx, ky)], color=colors_2d[i], lw=3)\n", - " ax2.axhline(\n", - " target_powers[(kx, ky)],\n", - " color=colors_2d[i],\n", - " linestyle=\"dotted\",\n", - " linewidth=2,\n", - " alpha=0.5,\n", - " )\n", - "\n", - "ax2.set_xlabel(\"Epochs\", fontsize=20)\n", - "ax2.set_ylabel(\"Power in Prediction\", fontsize=24)\n", - "ax2.grid(True, alpha=0.3)\n", - "style_axes(ax2)\n", - "ax2.yaxis.set_major_formatter(FormatStrFormatter(\"%.1f\"))\n", - "\n", - "# No legend; plots are closer via reduced hspace and smaller figure height\n", - "plt.savefig(\"train-loss.pdf\", bbox_inches=\"tight\")\n", - "plt.show()" - ], - "execution_count": null, - "outputs": [], - "id": "6ba47465-6ab8-41dd-8b2d-242a27217866" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Plot reference Fourier modes (irreps of G)" - ], - "id": "2f2b63f9" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "import os\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "import matplotlib.patheffects as pe\n", - "\n", - "# Expected to exist:\n", - "# - tracked_freqs : list[tuple[int,int]]\n", - "# - p1, p2 : ints\n", - "\n", - "# Colors\n", - "try:\n", - " colors_2d\n", - "except NameError:\n", - " colors_2d = plt.cm.tab10(np.linspace(0, 1, len(tracked_freqs)))\n", - "\n", - "\n", - "def fourier_mode_2d(p1: int, p2: int, kx: int, ky: int, phase: float = 0.0):\n", - " y = np.arange(p1)[:, None]\n", - " x = np.arange(p2)[None, :]\n", - " mode = np.cos(2 * np.pi * (ky * y / p1 + kx * x / p2) + phase)\n", - " mmin, mmax = mode.min(), mode.max()\n", - " return (mode - mmin) / (mmax - mmin) if mmax > mmin else mode\n", - "\n", - "\n", - "def signed_k(k: int, n: int) -> int:\n", - " return k if k <= n // 2 else k - n\n", - "\n", - "\n", - "def pretty_k(k: int, n: int) -> str:\n", - " if n % 2 == 0 and k == n // 2:\n", - " return r\"\\pm{}\".format(n // 2)\n", - " return f\"{signed_k(k, n)}\"\n", - "\n", - "\n", - "# ---- Save each mode separately ----\n", - "out_dir = \"fourier_modes\"\n", - "os.makedirs(out_dir, exist_ok=True)\n", - "\n", - "for i, (kx, ky) in enumerate(tracked_freqs):\n", - " img = fourier_mode_2d(p1, p2, kx, ky)\n", - "\n", - " fig, ax = plt.subplots(figsize=(3.2, 2.2))\n", - " ax.imshow(img, cmap=\"RdBu_r\", origin=\"upper\")\n", - " ax.set_xticks([])\n", - " ax.set_yticks([])\n", - "\n", - " # Colored border\n", - " for side in (\"left\", \"right\", \"top\", \"bottom\"):\n", - " ax.spines[side].set_edgecolor(colors_2d[i])\n", - " ax.spines[side].set_linewidth(8)\n", - "\n", - " # Big, colored, centered label\n", - " kx_label = pretty_k(kx, p2)\n", - " ky_label = pretty_k(ky, p1)\n", - " ax.text(\n", - " 0.5,\n", - " 0.5,\n", - " f\"$k=({kx_label},{ky_label})$\",\n", - " color=colors_2d[i],\n", - " fontsize=25,\n", - " fontweight=\"bold\",\n", - " ha=\"center\",\n", - " va=\"center\",\n", - " transform=ax.transAxes,\n", - " path_effects=[pe.withStroke(linewidth=3, foreground=\"white\", alpha=0.8)],\n", - " )\n", - "\n", - " plt.tight_layout()\n", - "\n", - " # Filenames include both raw and signed indices\n", - " kx_signed, ky_signed = signed_k(kx, p2), signed_k(ky, p1)\n", - " base = f\"mode_{i:03d}_kx{kx}_ky{ky}_signed_{kx_signed}_{ky_signed}\"\n", - " 
png_path = os.path.join(out_dir, base + \".png\")\n", - " npy_path = os.path.join(out_dir, base + \".npy\")\n", - "\n", - " fig.savefig(png_path, dpi=300, bbox_inches=\"tight\")\n", - " plt.close(fig)\n", - "\n", - " # Optional: save the normalized array too\n", - " np.save(npy_path, img)\n", - "\n", - "print(f\"Saved {len(tracked_freqs)} mode images (and .npy arrays) to: {out_dir}/\")" - ], - "execution_count": null, - "outputs": [], - "id": "6f3db659-353d-40b2-94cf-37ed2153f0b0" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "# ---- Save one tall, vertically stacked image (labels on RIGHT) with extra whitespace\n", - "# and safe padding so thick borders aren't trimmed ----\n", - "stack_name = \"fourier_modes_stacked_rightlabel_spaced.png\" # or \".pdf\"\n", - "stack_path = os.path.join(out_dir, stack_name)\n", - "\n", - "n = len(tracked_freqs)\n", - "\n", - "# Panel geometry and gap (increase gap_h_in for more whitespace)\n", - "panel_h_in = 2.2\n", - "gap_h_in = 0.35 # whitespace between rows\n", - "# total width = image col + label col\n", - "fig_w_in = 4.6\n", - "fig_h_in = n * panel_h_in + (n - 1) * gap_h_in\n", - "\n", - "import matplotlib.gridspec as gridspec\n", - "\n", - "fig = plt.figure(figsize=(fig_w_in, fig_h_in), dpi=300)\n", - "\n", - "# Rows alternate: [panel, gap, panel, gap, ..., panel]\n", - "rows = 2 * n - 1\n", - "height_ratios = []\n", - "for i in range(n):\n", - " height_ratios.append(panel_h_in)\n", - " if i < n - 1:\n", - " height_ratios.append(gap_h_in)\n", - "\n", - "# Right-labeled: image on LEFT, label on RIGHT\n", - "gs = gridspec.GridSpec(\n", - " nrows=rows, ncols=2,\n", - " width_ratios=[1.0, 0.46], # image : label (bump right col if labels are long)\n", - " height_ratios=height_ratios,\n", - " wspace=0.0, hspace=0.0\n", - ")\n", - "\n", - "for i, (kx, ky) in enumerate(tracked_freqs):\n", - " r = 2 * i # even rows are content; odd rows are spacers\n", - "\n", - " # Image axis (left)\n", - " ax_img = fig.add_subplot(gs[r, 0])\n", - " img = fourier_mode_2d(p1, p2, kx, ky)\n", - " ax_img.imshow(img, cmap=\"RdBu_r\", origin=\"upper\", aspect=\"equal\")\n", - " ax_img.set_xticks([]); ax_img.set_yticks([])\n", - "\n", - " # Colored border around the image only\n", - " for side in (\"left\", \"right\", \"top\", \"bottom\"):\n", - " ax_img.spines[side].set_edgecolor(colors_2d[i])\n", - " ax_img.spines[side].set_linewidth(8)\n", - "\n", - " # Label axis (right)\n", - " ax_label = fig.add_subplot(gs[r, 1])\n", - " ax_label.set_axis_off()\n", - " kx_label = pretty_k(kx, p2)\n", - " ky_label = pretty_k(ky, p1)\n", - " ax_label.text(\n", - " 0.0, 0.5, f\"$k=({kx_label},{ky_label})$\",\n", - " color=colors_2d[i], fontsize=45, fontweight=\"bold\",\n", - " ha=\"left\", va=\"center\", transform=ax_label.transAxes,\n", - " path_effects=[pe.withStroke(linewidth=3, foreground=\"white\", alpha=0.8)]\n", - " )\n", - "\n", - "# Pull content slightly in from the canvas edges so spines aren't flush\n", - "fig.subplots_adjust(left=0.02, right=0.98, top=0.985, bottom=0.015)\n", - "\n", - "# Save with a tiny pad so thick borders aren't clipped by 'tight' bbox\n", - "fig.savefig(stack_path, dpi=300, bbox_inches=\"tight\", pad_inches=0.12)\n", - "plt.close(fig)\n", - "\n", - "print(f\"Saved vertical stack with RIGHT labels and spacing (no edge clipping): {stack_path}\")" - ], - "execution_count": null, - "outputs": [], - "id": "996520ee-f19f-48a4-b453-a22274a91661" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## W_out Plot" - ], - "id": 
"5932a60e-e3d1-4011-946e-54dfae1a62f6" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "# ========= Inputs you already have =========\n", - "# param_hist_2d : list of state-dicts with key \"W_out\"\n", - "# p1, p2 : reshape dims (p1 * p2 == D)\n", - "\n", - "# ========= Config =========\n", - "steps_2d = [1, 5] # 100, len(param_hist_2d) - 1]\n", - "dead_thresh_l2 = 0.25 # absolute L2-norm threshold for \"dead\" neurons\n", - "heat_cmap = \"RdBu_r\" # colormap for heatmaps (weight values)\n", - "border_lw = 5.0 # border width for neuron tiles\n", - "title_fs = 18 # suptitle font size\n", - "\n", - "# Colors for categories\n", - "palette = plt.cm.tab10(np.linspace(0, 1, len(tracked_freqs)))\n", - "dead_color = (0.6, 0.6, 0.6, 1.0)\n", - "\n", - "\n", - "# ========= Tiny helpers =========\n", - "def squareish_grid(n):\n", - " c = int(np.ceil(np.sqrt(n)))\n", - " r = int(np.ceil(n / c))\n", - " return r, c\n", - "\n", - "\n", - "def tracked_power_from_fft2(power2d, kx, ky, p1, p2):\n", - " \"\"\"Sum power at (kx,ky) and its real-signal mirror (-kx,-ky).\"\"\"\n", - " i0, j0 = kx % p1, ky % p2\n", - " i1, j1 = (-kx) % p1, (-ky) % p2\n", - " if (i0, j0) == (i1, j1):\n", - " return power2d[i0, j0]\n", - " return power2d[i0, j0] + power2d[i1, j1]\n", - "\n", - "\n", - "# ========= Prep: sizes and global color limits =========\n", - "W0 = param_hist_2d[steps_2d[0]][\"W_out\"].detach().cpu().numpy().T # (H, D)\n", - "H, D = W0.shape\n", - "assert p1 * p2 == D, f\"p1*p2 ({p1*p2}) must equal D ({D}).\"\n", - "\n", - "vmin, vmax = np.inf, -np.inf\n", - "for step in steps_2d:\n", - " W = param_hist_2d[step][\"W_out\"].detach().cpu().numpy().T\n", - " vmin = min(vmin, W.min())\n", - " vmax = max(vmax, W.max())\n", - "\n", - "# ========= One figure per time step =========\n", - "R_ner, C_ner = squareish_grid(H)\n", - "tile_w, tile_h = 2, 2 # inches per neuron tile\n", - "figsize = (C_ner * tile_w, R_ner * tile_h)\n", - "\n", - "for step in steps_2d:\n", - " W = param_hist_2d[step][\"W_out\"].detach().cpu().numpy().T # (H, D)\n", - "\n", - " # Dominant tracked frequency + dead mask\n", - " dom_idx = np.empty(H, dtype=int)\n", - " l2 = np.linalg.norm(W, axis=1)\n", - " dead_mask = l2 < dead_thresh_l2\n", - "\n", - " for j in range(H):\n", - " m = W[j].reshape(p1, p2)\n", - " F = np.fft.fft2(m)\n", - " P = (F.conj() * F).real\n", - " tp = [tracked_power_from_fft2(P, kx, ky, p1, p2) for (kx, ky) in tracked_freqs]\n", - " dom_idx[j] = int(np.argmax(tp))\n", - "\n", - " edge_colors = palette[dom_idx]\n", - " edge_colors[dead_mask] = dead_color\n", - "\n", - " # Build figure for this time step\n", - " fig = plt.figure(figsize=figsize)\n", - " # fig.suptitle(f\"step {step}\", fontsize=title_fs, fontweight=\"bold\", y=0.98)\n", - " gs = gridspec.GridSpec(R_ner, C_ner, figure=fig, wspace=0.06, hspace=0.06)\n", - "\n", - " # Plot tiles\n", - " for j in range(R_ner * C_ner):\n", - " ax = fig.add_subplot(gs[j // C_ner, j % C_ner])\n", - " if j < H:\n", - " m = W[j].reshape(p1, p2)\n", - " im = ax.imshow(\n", - " m, vmin=vmin, vmax=vmax, origin=\"lower\", aspect=\"equal\", cmap=heat_cmap\n", - " )\n", - " # border highlight\n", - " ec = edge_colors[j]\n", - " for sp in ax.spines.values():\n", - " sp.set_edgecolor(ec)\n", - " sp.set_linewidth(border_lw)\n", - " else:\n", - " ax.axis(\"off\")\n", - "\n", - " ax.set_xticks([])\n", - " ax.set_yticks([])\n", - " plt.savefig(f\"step {step}\", bbox_inches=\"tight\", dpi=200)\n", - " plt.show()\n", - "\n", - "# ========= Standalone GLOBAL COLORBAR figure 
=========\n", - "fig_cb = plt.figure(figsize=(4, 1.0)) # wide short bar\n", - "ax_cb = fig_cb.add_axes([0.1, 0.35, 0.8, 0.3]) # [left, bottom, width, height]\n", - "norm = Normalize(vmin=vmin, vmax=vmax)\n", - "sm = cm.ScalarMappable(norm=norm, cmap=heat_cmap)\n", - "cbar = fig_cb.colorbar(sm, cax=ax_cb, orientation=\"horizontal\")\n", - "cbar.set_label(\"Weight value\")\n", - "plt.show()\n", - "\n", - "# ========= Standalone LEGEND figure for dominant frequency =========\n", - "fig_legend = plt.figure(figsize=(5, 1.6))\n", - "ax_leg = fig_legend.add_subplot(111)\n", - "ax_leg.axis(\"off\")\n", - "\n", - "# Use patches with COLORED EDGES (to match tile borders)\n", - "handles = [\n", - " Patch(\n", - " facecolor=\"white\", edgecolor=palette[i], linewidth=2.5, label=tracked_freqs[i]\n", - " )\n", - " for i in range(len(tracked_freqs))\n", - "]\n", - "handles.append(\n", - " Patch(facecolor=\"white\", edgecolor=dead_color, linewidth=2.5, label=\"dead\")\n", - ")\n", - "\n", - "leg = ax_leg.legend(\n", - " handles=handles, ncol=4, frameon=True, loc=\"center\", title=\"Dominant frequency\"\n", - ")\n", - "plt.show()" - ], - "execution_count": null, - "outputs": [], - "id": "44adb93e-a396-4b7c-8350-056934a49ddd" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## W_mix Plot" - ], - "id": "d3232d4a-d65e-4282-be33-00da055ebb98" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "# =========================\n", - "# Config\n", - "# =========================\n", - "within_group_order = \"phase\" # 'phase' | 'power' | 'phase_power' | 'none'\n", - "outfile = \"Wmix_grouped_by_Wo_freq_labeled.pdf\"\n", - "cmap = \"RdBu_r\"\n", - "\n", - "TITLE_FONTSIZE = 18\n", - "LABEL_FONTSIZE = 11\n", - "CBAR_LABEL_SIZE = 12\n", - "CBAR_TICK_SIZE = 11\n", - "SEPARATOR_LW = 0.9\n", - "BLOCK_EDGE_LW = 2.0\n", - "\n", - "dead_l2_thresh = 0.1\n", - "dead_position = \"last\"\n", - "\n", - "# how many template freqs to track/label\n", - "num_freqs_to_track = 7 # e.g., like your old list length\n", - "\n", - "# You must have template_2d (shape p1 x p2) defined.\n", - "p1, p2 = template_2d.shape\n", - "\n", - "\n", - "# =========================\n", - "# Helpers (2D, top-K tracked)\n", - "# =========================\n", - "def topk_template_freqs(template_2d: np.ndarray, K: int, min_power: float = 1e-20):\n", - " # Use your get_power_2d_adele(template_2d) -> (freqs_u, freqs_v, power)\n", - " freqs_u, freqs_v, power = get_power_2d_adele(\n", - " template_2d\n", - " ) # power: (p1, p2//2+1) or similar\n", - " shp = power.shape\n", - " flat = power.ravel()\n", - " mask = flat > min_power\n", - " if not np.any(mask):\n", - " return []\n", - " top_idx = np.flatnonzero(mask)[np.argsort(flat[mask])[::-1]][:K]\n", - " kx, ky = np.unravel_index(top_idx, shp)\n", - " return list(zip(kx.tolist(), ky.tolist()))\n", - "\n", - "\n", - "def tracked_power_from_fft2(power2d, kx, ky, p1, p2):\n", - " i0, j0 = kx % p1, ky % p2\n", - " i1, j1 = (-kx) % p1, (-ky) % p2\n", - " if (i0, j0) == (i1, j1):\n", - " return float(power2d[i0, j0])\n", - " return float(power2d[i0, j0] + power2d[i1, j1])\n", - "\n", - "\n", - "def analyze_Wo_tracked(sd, tracked_freqs, p1, p2):\n", - " \"\"\"\n", - " For each neuron row of W_o, find the dominant freq among tracked_freqs.\n", - " Returns:\n", - " dom_idx (index into tracked_freqs), phase (at rep bin), dom_power, l2, D\n", - " \"\"\"\n", - " Wo = sd[\"W_out\"].detach().cpu().numpy() # (p, H) with p = p1*p2\n", - " W = Wo.T # (H, p)\n", - " H, D = W.shape\n", - " assert D 
== p1 * p2\n", - "\n", - " dom_idx = np.empty(H, dtype=int)\n", - " dom_pow = np.empty(H, dtype=float)\n", - " phase = np.empty(H, dtype=float)\n", - " l2 = np.linalg.norm(W, axis=1)\n", - "\n", - " for j in range(H):\n", - " m = W[j].reshape(p1, p2)\n", - " F = np.fft.fft2(m)\n", - " P = (F.conj() * F).real\n", - " # power only at tracked bins (with symmetry accounted)\n", - " tp = [tracked_power_from_fft2(P, kx, ky, p1, p2) for (kx, ky) in tracked_freqs]\n", - " jj = int(np.argmax(tp))\n", - " dom_idx[j] = jj\n", - " # phase at representative positive bin\n", - " i0, j0 = tracked_freqs[jj][0] % p1, tracked_freqs[jj][1] % p2\n", - " phase[j] = np.angle(F[i0, j0])\n", - " dom_pow[j] = tp[jj]\n", - " return dom_idx, phase, dom_pow, l2, D\n", - "\n", - "\n", - "def permutation_from_groups_with_dead(\n", - " dom_idx,\n", - " phase,\n", - " dom_power,\n", - " l2,\n", - " *,\n", - " within=\"phase\",\n", - " dead_l2_thresh=1e-1,\n", - " dead_position=\"last\",\n", - "):\n", - " dead_mask = l2 < float(dead_l2_thresh)\n", - " groups = {}\n", - " for i, f in enumerate(dom_idx):\n", - " key = -1 if dead_mask[i] else int(f) # f is index into tracked_freqs\n", - " groups.setdefault(key, []).append(i)\n", - "\n", - " freq_keys = sorted([k for k in groups.keys() if k >= 0])\n", - " ordered_keys = (freq_keys + [-1]) if dead_position == \"last\" else ([-1] + freq_keys)\n", - " group_keys = [k for k in ordered_keys if k in groups]\n", - "\n", - " perm, boundaries = [], []\n", - " for f in group_keys:\n", - " idxs = groups[f]\n", - " if f == -1:\n", - " idxs = sorted(idxs, key=lambda i: l2[i])\n", - " else:\n", - " if within == \"phase\" and phase is not None:\n", - " idxs = sorted(idxs, key=lambda i: (phase[i] + 2 * np.pi) % (2 * np.pi))\n", - " elif within == \"power\" and dom_power is not None:\n", - " idxs = sorted(idxs, key=lambda i: -dom_power[i])\n", - " elif (\n", - " within == \"phase_power\" and phase is not None and dom_power is not None\n", - " ):\n", - " idxs = sorted(\n", - " idxs,\n", - " key=lambda i: ((phase[i] + 2 * np.pi) % (2 * np.pi), -dom_power[i]),\n", - " )\n", - " perm.extend(idxs)\n", - " boundaries.append(len(perm))\n", - " return np.array(perm, dtype=int), group_keys, boundaries\n", - "\n", - "\n", - "def reorder_square(M, perm):\n", - " return M[perm][:, perm]\n", - "\n", - "\n", - "# labels & colors for the tracked list\n", - "tracked_freqs = topk_template_freqs(template_2d, num_freqs_to_track)\n", - "tracked_labels = [\n", - " (\"DC\" if (kx, ky) == (0, 0) else f\"({kx},{ky})\") for (kx, ky) in tracked_freqs\n", - "]\n", - "\n", - "\n", - "def build_freq_colors_tracked(n):\n", - " return plt.cm.tab10(np.linspace(0, 1, n))\n", - "\n", - "\n", - "freq_colors = build_freq_colors_tracked(len(tracked_freqs))\n", - "dead_gray = \"0.35\"\n", - "\n", - "\n", - "def add_group_labels_top(\n", - " ax, group_keys, boundaries, *, show_counts=True, rotation=0, fontsize=LABEL_FONTSIZE\n", - "):\n", - " starts = [0] + boundaries[:-1]\n", - " ends = [b - 1 for b in boundaries]\n", - " centers = [(s + e) / 2.0 for s, e in zip(starts, ends)]\n", - " sizes = [e - s + 1 for s, e in zip(starts, ends)]\n", - "\n", - " labels = []\n", - " colors = []\n", - " for kk, nn in zip(group_keys, sizes):\n", - " if kk == -1:\n", - " base = \"DEAD\"\n", - " clr = dead_gray\n", - " else:\n", - " base = tracked_labels[kk]\n", - " clr = freq_colors[kk]\n", - " labels.append(f\"{base}\\n(n={nn})\" if show_counts else base)\n", - " colors.append(clr)\n", - "\n", - " ax.set_xticks(centers)\n", - " 
ax.set_xticklabels(labels, rotation=rotation, fontsize=fontsize, ha=\"center\")\n", - " ax.tick_params(\n", - " axis=\"x\",\n", - " bottom=False,\n", - " top=True,\n", - " labelbottom=False,\n", - " labeltop=True,\n", - " labelsize=fontsize,\n", - " )\n", - " for lbl, clr in zip(ax.get_xticklabels(), colors):\n", - " lbl.set_color(clr)\n", - "\n", - "\n", - "# =========================\n", - "# Prepare steps & snapshots\n", - "# =========================\n", - "steps = [1, 5] # [50, 100, len(param_hist_2d) - 1]\n", - "Wh_perm_list = []\n", - "group_info_list = []\n", - "\n", - "for s in steps:\n", - " sd = param_hist_2d[s]\n", - "\n", - " # analyze W_o against tracked 2D freqs\n", - " dom_idx, phase, dom_power, l2, D = analyze_Wo_tracked(sd, tracked_freqs, p1, p2)\n", - "\n", - " # W_mix fallback to W_h\n", - " if \"W_mix\" in sd:\n", - " M = sd[\"W_mix\"].detach().cpu().numpy()\n", - " elif \"W_h\" in sd:\n", - " M = sd[\"W_h\"].detach().cpu().numpy()\n", - " else:\n", - " raise KeyError(\"Neither 'W_mix' nor 'W_h' found in state dict.\")\n", - "\n", - " perm, group_keys, boundaries = permutation_from_groups_with_dead(\n", - " dom_idx,\n", - " phase,\n", - " dom_power,\n", - " l2,\n", - " within=within_group_order,\n", - " dead_l2_thresh=dead_l2_thresh,\n", - " dead_position=dead_position,\n", - " )\n", - "\n", - " M_perm = reorder_square(M, perm)\n", - " Wh_perm_list.append(M_perm)\n", - " group_info_list.append((group_keys, boundaries))\n", - "\n", - "# Shared symmetric color limits\n", - "vmax = max(np.max(np.abs(M)) for M in Wh_perm_list)\n", - "vmin = -vmax if vmax > 0 else 0.0\n", - "\n", - "# =========================\n", - "# Plot\n", - "# =========================\n", - "n = len(steps)\n", - "fig, axes = plt.subplots(1, n, figsize=(3.8 * n, 3.8), constrained_layout=True)\n", - "try:\n", - " fig.set_constrained_layout_pads(\n", - " w_pad=0.003, h_pad=0.003, wspace=0.003, hspace=0.003\n", - " )\n", - "except Exception:\n", - " pass\n", - "if n == 1:\n", - " axes = [axes]\n", - "\n", - "im = None\n", - "for j, (s, M_perm) in enumerate(zip(steps, Wh_perm_list)):\n", - " ax = axes[j]\n", - " im = ax.imshow(\n", - " M_perm, cmap=cmap, vmin=vmin, vmax=vmax, aspect=\"equal\", interpolation=\"nearest\"\n", - " )\n", - "\n", - " ax.set_yticks([])\n", - " ax.tick_params(axis=\"x\", bottom=False)\n", - "\n", - " group_keys, boundaries = group_info_list[j]\n", - " for b in boundaries[:-1]:\n", - " ax.axhline(b - 0.5, color=\"k\", lw=SEPARATOR_LW, alpha=0.65)\n", - " ax.axvline(b - 0.5, color=\"k\", lw=SEPARATOR_LW, alpha=0.65)\n", - "\n", - " starts = [0] + boundaries[:-1]\n", - " ends = [b - 1 for b in boundaries]\n", - " for kk, s0, e0 in zip(group_keys, starts, ends):\n", - " if kk == -1: # skip DEAD box color\n", - " continue\n", - " size = e0 - s0 + 1\n", - " rect = Rectangle(\n", - " (s0 - 0.5, s0 - 0.5),\n", - " width=size,\n", - " height=size,\n", - " fill=False,\n", - " linewidth=BLOCK_EDGE_LW,\n", - " edgecolor=freq_colors[kk],\n", - " alpha=0.95,\n", - " joinstyle=\"miter\",\n", - " )\n", - " ax.add_patch(rect)\n", - "\n", - " add_group_labels_top(\n", - " ax,\n", - " group_keys,\n", - " boundaries,\n", - " show_counts=True,\n", - " rotation=0,\n", - " fontsize=LABEL_FONTSIZE,\n", - " )\n", - " ax.set_xlabel(f\"t = {s}\", fontsize=TITLE_FONTSIZE, labelpad=8)\n", - "\n", - "# shared colorbar\n", - "cbar = fig.colorbar(im, ax=axes, shrink=1.0, pad=0.012, aspect=18)\n", - "cbar.ax.tick_params(labelsize=CBAR_TICK_SIZE)\n", - "cbar.set_label(\"Weight value\", 
fontsize=CBAR_LABEL_SIZE)\n", - "\n", - "plt.savefig(outfile, bbox_inches=\"tight\", dpi=200)\n", - "plt.show()" - ], - "execution_count": null, - "outputs": [], - "id": "6caf85e2-4706-4695-8499-16aad2a3061d" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "# =========================\n", - "# Save tight heatmaps (keep separators + rectangles; no labels/legend)\n", - "# =========================\n", - "import os\n", - "from matplotlib.patches import Rectangle\n", - "\n", - "base, ext = os.path.splitext(outfile)\n", - "if not ext:\n", - " ext = \".png\" # change to \".pdf\" if preferred\n", - "\n", - "# Shared symmetric color limits across steps\n", - "vmax = max(np.max(np.abs(M)) for M in Wh_perm_list)\n", - "vmin = -vmax if vmax > 0 else 0.0\n", - "\n", - "for j, (s, M_perm) in enumerate(zip(steps, Wh_perm_list)):\n", - " # Full-bleed figure (no margins)\n", - " fig = plt.figure(figsize=(3.2, 3.2), dpi=200)\n", - " ax = plt.Axes(fig, [0, 0, 1, 1]) # left, bottom, width, height (normalized)\n", - " fig.add_axes(ax)\n", - " ax.set_axis_off() # no ticks/labels/frame\n", - "\n", - " # Heatmap\n", - " ax.imshow(\n", - " M_perm, cmap=cmap, vmin=vmin, vmax=vmax, aspect=\"equal\", interpolation=\"nearest\"\n", - " )\n", - "\n", - " # Group separators + rectangles\n", - " group_keys, boundaries = group_info_list[j]\n", - "\n", - " # Thin separator lines between groups\n", - " for b in boundaries[:-1]:\n", - " ax.axhline(b - 0.5, color=\"k\", lw=SEPARATOR_LW, alpha=0.65)\n", - " ax.axvline(b - 0.5, color=\"k\", lw=SEPARATOR_LW, alpha=0.65)\n", - "\n", - " # Colored block outlines for each non-DEAD group\n", - " starts = [0] + boundaries[:-1]\n", - " ends = [b - 1 for b in boundaries]\n", - " for kk, s0, e0 in zip(group_keys, starts, ends):\n", - " if kk == -1:\n", - " continue # skip DEAD group outline\n", - " size = e0 - s0 + 1\n", - " rect = Rectangle(\n", - " (s0 - 0.5, s0 - 0.5),\n", - " width=size,\n", - " height=size,\n", - " fill=False,\n", - " linewidth=BLOCK_EDGE_LW,\n", - " edgecolor=freq_colors[kk],\n", - " alpha=0.95,\n", - " joinstyle=\"miter\",\n", - " )\n", - " ax.add_patch(rect)\n", - "\n", - " # Save ultra-tight (no padding, no legend/colorbar/labels)\n", - " per_step_outfile = f\"{base}_t{s:04d}_tight{ext}\"\n", - " fig.savefig(per_step_outfile, bbox_inches=\"tight\", pad_inches=0)\n", - " plt.close(fig)\n", - "\n", - "print(\n", - " f\"Saved {len(steps)} files like '{base}_t####_tight{ext}' (tight, no labels/legend).\"\n", - ")" - ], - "execution_count": null, - "outputs": [], - "id": "acd308eb-032e-46e2-8a55-582e2d06a793" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Phase Plot" - ], - "id": "3493ec18-299a-4642-94ec-a13723c80882" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "# One figure with a POLAR SUBPLOT per module (tracked (kx, ky)), excluding DEAD and DC-dominant.\n", - "# Uses the final checkpoint only (t = -1).\n", - "\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from matplotlib.lines import Line2D\n", - "\n", - "dead_thresh_l2 = 0.5 # absolute L2 threshold for \"dead\" neurons\n", - "\n", - "# palette (one color per tracked freq)\n", - "if 'palette' not in locals() or len(palette) != len(tracked_freqs):\n", - " palette = plt.cm.tab10(np.linspace(0, 1, len(tracked_freqs)))\n", - "\n", - "# ---------- helpers ----------\n", - "def squareish_grid(n):\n", - " c = int(np.ceil(np.sqrt(n)))\n", - " r = int(np.ceil(n / c))\n", - " return r, c\n", - "\n", - "def 
tracked_power_from_fft2(P, kx, ky, p1, p2):\n", - " i0, j0 = kx % p1, ky % p2\n", - " i1, j1 = (-kx) % p1, (-ky) % p2\n", - " return P[i0, j0] if (i0, j0) == (i1, j1) else P[i0, j0] + P[i1, j1]\n", - "\n", - "def canonical_bin(kx, ky, p1, p2):\n", - " i, j = kx % p1, ky % p2\n", - " if j > p2 - j: # prefer upper half-plane in ky\n", - " i, j = (-kx) % p1, (-ky) % p2\n", - " if j == 0 or (p2 % 2 == 0 and j == p2 // 2): # ky edge: fold kx too\n", - " if i > p1 - i:\n", - " i = (-kx) % p1\n", - " return i, j\n", - "\n", - "def phase_at_2d_bin(m2d, kx, ky, p1, p2):\n", - " F = np.fft.fft2(m2d)\n", - " i, j = canonical_bin(kx, ky, p1, p2)\n", - " return float(np.angle(F[i, j])) # [-pi, pi]\n", - "\n", - "# ---------- final step only ----------\n", - "sd = param_hist_2d[-1]\n", - "Wo = sd[\"W_out\"].detach().cpu().numpy() # (p1*p2, H) or (H, p1*p2)\n", - "if Wo.shape[0] == p1 * p2:\n", - " W = Wo.T\n", - "elif Wo.shape[1] == p1 * p2:\n", - " W = Wo\n", - "else:\n", - " raise ValueError(f\"W_o has incompatible shape {Wo.shape}; expected one dim == p1*p2={p1*p2}\")\n", - "H, D = W.shape\n", - "assert D == p1 * p2\n", - "\n", - "# per-neuron dominant tracked freq, phase, radius\n", - "dom_idx = np.full(H, -1, dtype=int)\n", - "phase = np.full(H, np.nan, dtype=float)\n", - "radius = np.linalg.norm(W, axis=1)\n", - "dead = radius < dead_thresh_l2\n", - "\n", - "for i in range(H):\n", - " if dead[i]:\n", - " continue\n", - " m = W[i].reshape(p1, p2)\n", - " F = np.fft.fft2(m); P = (F.conj() * F).real\n", - " tp = [tracked_power_from_fft2(P, kx, ky, p1, p2) for (kx, ky) in tracked_freqs]\n", - " j = int(np.argmax(tp))\n", - " kx, ky = tracked_freqs[j]\n", - " if (kx, ky) == (0, 0): # DC has undefined phase \u2192 skip\n", - " continue\n", - " dom_idx[i] = j\n", - " phase[i] = phase_at_2d_bin(m, kx, ky, p1, p2)\n", - "\n", - "# ---------- build module masks (exclude empty and DEAD) ----------\n", - "valid = (~dead) & np.isfinite(phase)\n", - "modules = [(j, (valid & (dom_idx == j))) for j in range(len(tracked_freqs))]\n", - "modules = [(j, msk) for (j, msk) in modules if np.any(msk)] # keep non-empty modules only\n", - "\n", - "if len(modules) == 0:\n", - " print(\"No non-dead, non-DC neurons found for any tracked module.\")\n", - "else:\n", - " # global radial limit for comparability across subplots\n", - " rmax_global = max(float(radius[msk].max()) for _, msk in modules)\n", - "\n", - " R, C = squareish_grid(len(modules))\n", - " fig, axes = plt.subplots(R, C, figsize=(3.8*C + 3.0, 4.2*R), subplot_kw={\"projection\": \"polar\"})\n", - " axes = np.atleast_1d(axes).ravel()\n", - "\n", - " for ax, (j, msk) in zip(axes, modules):\n", - " th = phase[msk]\n", - " r = radius[msk]\n", - "\n", - " # sticks + markers\n", - " for th_i, r_i in zip(th, r):\n", - " ax.plot([th_i, th_i], [0.0, r_i], color=palette[j], linewidth=1.8, alpha=0.9)\n", - " ax.scatter(th, r, s=70, color=palette[j], edgecolors='none', alpha=0.95)\n", - "\n", - " # styling\n", - " ax.set_title(f\"{tracked_freqs[j]} | n={np.count_nonzero(msk)}\", fontsize=11, pad=8)\n", - " ax.set_theta_zero_location(\"E\")\n", - " ax.set_thetagrids(\n", - " np.arange(0, 360, 90), # sparser grid to reduce clutter\n", - " [\"0\", r\"$\\pi/2$\", r\"$\\pi$\", r\"$3\\pi/2$\"]\n", - " )\n", - " ax.set_yticklabels([])\n", - " ax.spines[\"polar\"].set_linewidth(2)\n", - " ax.set_rlim(0, 1.05 * rmax_global if rmax_global > 0 else 1.0)\n", - "\n", - " # hide any unused axes\n", - " for ax in axes[len(modules):]:\n", - " ax.set_visible(False)\n", - "\n", - " # legend 
with only shown modules, outside on the right\n", - " handles = [Line2D([0],[0], marker='o', linestyle='', color=palette[j],\n", - " label=f\"{tracked_freqs[j]}\", markersize=7)\n", - " for (j, _) in modules]\n", - " fig.legend(handles=handles, loc='center left', bbox_to_anchor=(1.02, 0.5),\n", - " frameon=True, title=\"module (kx, ky)\")\n", - "\n", - " fig.subplots_adjust(right=0.82, wspace=0.35, hspace=0.35)\n", - " # plt.savefig(\"phase_polar_modules_grid.png\", bbox_inches=\"tight\", dpi=170)\n", - " plt.show()" - ], - "execution_count": null, - "outputs": [], - "id": "d0a2b283-bb8b-4145-9519-3af401532cc4" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Activation Plot" - ], - "id": "97cdde1e-d4ef-4437-bc1e-5584ab50fd64" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "paths_xy = sequence_to_paths_xy(sequence_xy, p1, p2)\n", - "\n", - "# ========= Config =========\n", - "dead_l2_thresh = 1.0 # absolute L2 threshold (on W_o rows) for DEAD\n", - "if 'palette' not in locals() or len(palette) != len(tracked_freqs):\n", - " palette = plt.cm.tab10(np.linspace(0, 1, len(tracked_freqs)))\n", - "dead_color = (0.6, 0.6, 0.6, 1.0)\n", - "\n", - "tile_w, tile_h = 1.9, 1.9 # inches per neuron tile\n", - "border_lw = 2.0\n", - "title_fs = 14\n", - "cmap_heat = \"viridis\"\n", - "max_neurons_per_module = 9 # 3x3 grid\n", - "\n", - "# ========= Unroll hidden activations (bias-free quadratic RNN) =========\n", - "@torch.no_grad()\n", - "def unroll_hidden_batch_quadratic_no_bias(x_seq_t, rnn, batch_size=4096):\n", - " \"\"\"\n", - " For QuadraticRNN:\n", - " h0 = template @ W_in^T\n", - " h_t = (W_mix h_{t-1} + W_drive x_t)^2 for t = 0..k-1\n", - " Returns: H_all of shape (N, k, d) with the hidden AFTER each update t.\n", - " \"\"\"\n", - " N, K, p = x_seq_t.shape\n", - " d = rnn.W_mix.shape[0]\n", - " H_all = torch.empty((N, K, d), device=x_seq_t.device, dtype=rnn.W_mix.dtype)\n", - "\n", - " # template -> (p,) then h0_vec -> (d,)\n", - " tmpl = rnn.template\n", - " if tmpl.dim() == 2 and tmpl.size(0) == 1:\n", - " tmpl = tmpl.squeeze(0)\n", - " h0_vec = tmpl @ rnn.W_in.T # (d,)\n", - "\n", - " for start in range(0, N, batch_size):\n", - " end = min(N, start + batch_size)\n", - " xb = x_seq_t[start:end] # (B, K, p)\n", - " B = xb.size(0)\n", - "\n", - " # expand h0 across the batch\n", - " h = h0_vec.expand(B, -1).contiguous() # (B, d)\n", - "\n", - " for t in range(0, K):\n", - " xt = xb[:, t, :] # (B, p)\n", - " pre = (h @ rnn.W_mix.T) + (xt @ rnn.W_drive.T)\n", - " h = pre.pow(2)\n", - " H_all[start:end, t, :] = h\n", - "\n", - " return H_all\n", - "\n", - "\n", - "# ========= Collect activations =========\n", - "if 'device' not in locals():\n", - " device = next(rnn_2d.parameters()).device\n", - "X_seq_t = torch.tensor(X_seq_2d, dtype=torch.float32, device=device)\n", - "rnn_2d.to(device).eval()\n", - "with torch.no_grad():\n", - " H_all = unroll_hidden_batch_quadratic_no_bias(X_seq_t, rnn_2d, batch_size=4096) # (N, k, H)\n", - "\n", - "# ========= Prep numpy views =========\n", - "H_all_np = H_all.detach().cpu().numpy() # (N, k, H)\n", - "pos_xy = paths_xy.reshape(-1, 2) # (N*k, 2)\n", - "pos_lin = pos_xy[:, 0] * p2 + pos_xy[:, 1] # linearized index in [0..p1*p2-1]\n", - "acts_flat = H_all_np.reshape(-1, H_all_np.shape[2]) # (N*k, H)\n", - "\n", - "# Precompute counts per 2D position (shared by all neurons)\n", - "counts_lin = np.bincount(pos_lin, minlength=p1*p2).astype(np.int64)\n", - "counts = counts_lin.reshape(p1, p2)\n", - "\n", - "# ========= 
Module assignment from W_o using tracked 2D frequencies =========\n", - "Wo = rnn_2d.W_out.detach().cpu().numpy() # (p1*p2, H)\n", - "W = Wo.T # (H, p1*p2)\n", - "Hdim = W.shape[0]\n", - "l2_norm = np.linalg.norm(W, axis=1)\n", - "is_dead = (l2_norm < dead_l2_thresh)\n", - "\n", - "def tracked_power_from_fft2(power2d, kx, ky, p1, p2):\n", - " i0, j0 = kx % p1, ky % p2\n", - " i1, j1 = (-kx) % p1, (-ky) % p2\n", - " if (i0, j0) == (i1, j1):\n", - " return power2d[i0, j0]\n", - " return power2d[i0, j0] + power2d[i1, j1]\n", - "\n", - "dom_idx = np.empty(Hdim, dtype=int) # index into tracked_freqs\n", - "dom_pw = np.empty(Hdim, dtype=float)\n", - "phase = np.empty(Hdim, dtype=float) # phase at representative (i0, j0)\n", - "\n", - "for j in range(Hdim):\n", - " m = W[j].reshape(p1, p2)\n", - " F = np.fft.fft2(m)\n", - " P = (F.conj() * F).real\n", - " tp = [tracked_power_from_fft2(P, kx, ky, p1, p2) for (kx, ky) in tracked_freqs]\n", - " jj = int(np.argmax(tp))\n", - " dom_idx[j] = jj\n", - " # phase at a consistent representative bin\n", - " i0, j0 = tracked_freqs[jj][0] % p1, tracked_freqs[jj][1] % p2\n", - " phase[j] = np.angle(F[i0, j0])\n", - " dom_pw[j] = tp[jj]\n", - "\n", - "# Assign module id (DEAD = -1)\n", - "module_id = np.where(is_dead, -1, dom_idx)\n", - "\n", - "# Group neurons by module (dead last)\n", - "groups = {}\n", - "for nid, mid in enumerate(module_id):\n", - " groups.setdefault(int(mid), []).append(nid)\n", - "\n", - "freq_keys = [i for i in range(len(tracked_freqs)) if i in groups] # keep tracked order\n", - "ordered_mods = freq_keys + ([-1] if -1 in groups else [])\n", - "\n", - "# ========= Pick up to 9 neurons/module (stable) =========\n", - "def pick_neurons_for_module(m, max_neurons=9):\n", - " idxs = groups[m]\n", - " if m == -1:\n", - " # DEAD: weakest first\n", - " idxs = sorted(idxs, key=lambda i: l2_norm[i])\n", - " else:\n", - " # sort by (phase mod 2\u03c0, then -dom_power) for reproducibility\n", - " idxs = sorted(idxs, key=lambda i: ((phase[i] + 2*np.pi) % (2*np.pi), -dom_pw[i]))\n", - " return idxs[:min(max_neurons, len(idxs))]\n", - "\n", - "picked = {m: pick_neurons_for_module(m, max_neurons_per_module) for m in ordered_mods}\n", - "\n", - "# ========= Build mean activation map for a neuron =========\n", - "# Efficient: single bincount per neuron; reuse global counts\n", - "def mean_activation_map_for_neuron(nid: int) -> np.ndarray:\n", - " sums_lin = np.bincount(pos_lin, weights=acts_flat[:, nid], minlength=p1*p2)\n", - " sums = sums_lin.reshape(p1, p2)\n", - " return sums / np.maximum(counts, 1) # avoid div-by-zero\n", - "\n", - "# ========= Plot: one 3\u00d73 figure per module (colorbar on far right) =========\n", - "for m in ordered_mods:\n", - " nids = picked[m]\n", - " if len(nids) == 0:\n", - " continue\n", - "\n", - " mlabel = (\"DEAD\" if m == -1 else tracked_freqs[m])\n", - " mcolor = (dead_color if m == -1 else palette[m])\n", - " safe = str(mlabel).replace(\"(\", \"\").replace(\")\", \"\").replace(\",\", \"_\").replace(\" \", \"\")\n", - "\n", - " # Compute maps for this module\n", - " maps = [mean_activation_map_for_neuron(nid) for nid in nids]\n", - "\n", - " # ---- LOG SCALE: choose vmin as smallest positive, vmax as max ----\n", - " vmax = max(float(mp.max()) for mp in maps)\n", - " posmins = [float(mp[mp > 0].min()) for mp in maps if np.any(mp > 0)]\n", - " use_log = (len(posmins) > 0) and (vmax > 0)\n", - " if use_log:\n", - " vmin_pos = min(posmins)\n", - " norm = LogNorm(vmin=vmin_pos, vmax=vmax)\n", - " else:\n", - " # fallback 
to linear if no positive values\n", - " vmin = min(float(mp.min()) for mp in maps)\n", - " norm = None # linear\n", - "\n", - " nrows, ncols = 3, 3\n", - " fig = plt.figure(figsize=(ncols * tile_w + 1.2, nrows * tile_h), constrained_layout=True)\n", - " gs = fig.add_gridspec(nrows=nrows, ncols=ncols + 1, width_ratios=[1, 1, 1, 0.06], wspace=0.08, hspace=0.08)\n", - "\n", - " axes_r, last_im = [], None\n", - " for r in range(nrows):\n", - " for c in range(ncols):\n", - " ax = fig.add_subplot(gs[r, c]); axes_r.append(ax)\n", - "\n", - " for i, ax in enumerate(axes_r):\n", - " if i < len(nids):\n", - " if use_log:\n", - " im = ax.imshow(maps[i], norm=norm, origin=\"lower\", aspect=\"equal\", cmap=cmap_heat)\n", - " # Optional: make zeros clearly visible\n", - " # im.cmap.set_under('white')\n", - " else:\n", - " im = ax.imshow(maps[i], vmin=vmin, vmax=vmax, origin=\"lower\", aspect=\"equal\", cmap=cmap_heat)\n", - " last_im = im\n", - " ax.set_title(f\"h={nids[i]}\", fontsize=10, color=mcolor)\n", - " for sp in ax.spines.values():\n", - " sp.set_edgecolor(mcolor); sp.set_linewidth(border_lw)\n", - " else:\n", - " ax.axis(\"off\")\n", - " ax.set_xticks([]); ax.set_yticks([])\n", - "\n", - " # Colorbar on the far right\n", - " cax = fig.add_subplot(gs[:, -1])\n", - " if last_im is not None:\n", - " cbar = fig.colorbar(last_im, cax=cax)\n", - " cbar.set_label(\"Mean activation (log scale)\" if use_log else \"Mean activation\", fontsize=11)\n", - " # Optional: nice log ticks\n", - " # if use_log:\n", - " # import numpy as np\n", - " # ticks = np.logspace(np.log10(vmin_pos), np.log10(vmax), num=5)\n", - " # cbar.set_ticks(ticks)\n", - "\n", - " fig.suptitle(\n", - " f\"Module {mlabel} | p1={p1}, p2={p2} (showing {len(nids)} of {len(groups[m])} neurons)\",\n", - " fontsize=title_fs, color=mcolor\n", - " )\n", - " fig.supxlabel(\"position y\", fontsize=12)\n", - " fig.supylabel(\"position x\", fontsize=12)\n", - "\n", - " plt.savefig(f\"activations2D_module_{safe}_3x3.png\", bbox_inches=\"tight\", dpi=170)\n", - " plt.show()" - ], - "execution_count": null, - "outputs": [], - "id": "667c8114-4a0b-480e-9ba2-022122926a8d" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [], - "execution_count": null, - "outputs": [], - "id": "be86d544-02bf-4623-b57b-15b16f3248a5" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [], - "execution_count": null, - "outputs": [], - "id": "b8cbe5d1-9668-405f-b452-ab92dcad91ef" - } - ], - "metadata": { - "kernelspec": { - "display_name": "group-agf", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.2" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file diff --git a/notebooks/C_n.ipynb b/notebooks/C_n.ipynb deleted file mode 100644 index 82a1833..0000000 --- a/notebooks/C_n.ipynb +++ /dev/null @@ -1,804 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "51d11caf-0971-4324-b63b-819b714a9c3c", - "metadata": {}, - "source": [ - "# Binary Group Composition with $C_n$" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "80f249f1-6985-4c73-86cd-04e1adac3e8d", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import random\n", - "import torch\n", - "import torch.nn as nn\n", - "from tqdm import tqdm\n", - "import torch.optim as 
optim\n", - "from torch.utils.data import DataLoader, TensorDataset\n", - "import seaborn as sns\n", - "import matplotlib.pyplot as plt\n", - "import matplotlib.cm as cm\n", - "from matplotlib.animation import FuncAnimation\n", - "from matplotlib.ticker import FormatStrFormatter\n", - "from matplotlib.ticker import FuncFormatter\n", - "from matplotlib.ticker import MaxNLocator" - ] - }, - { - "cell_type": "markdown", - "id": "9fd05577-db56-4d0a-bb93-1d0b48cecaf6", - "metadata": {}, - "source": [ - "## Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f19bd1ad-9e8f-4720-b317-afe13fafae88", - "metadata": {}, - "outputs": [], - "source": [ - "def one_hot(p):\n", - " \"\"\"One-hot encode an integer value in R^p.\"\"\"\n", - " vec = np.zeros(p)\n", - " vec[0] = 1\n", - " return vec\n", - "\n", - "def generate_template(p, magnitude, exponent):\n", - " weight = magnitude * np.power(np.arange(1, p), -exponent) # Power-law singular values\n", - " template = np.ones(p) # Base term (DC component)\n", - " for freq in range(1, p):\n", - " template += weight[freq-1] * np.cos(np.arange(p) * freq / p * 2 * np.pi)\n", - " return template / p\n", - "\n", - "def generate_fixed_template(p):\n", - " # Generate template array from Fourier spectrum\n", - " spectrum = np.zeros(p, dtype=complex)\n", - " \n", - " # Set only three frequencies with specific amplitudes\n", - " spectrum[1] = 12.5 # Positive frequency\n", - " spectrum[-1] = 12.5 # Negative frequency (conjugate)\n", - " spectrum[2] = 10 # Positive frequency\n", - " spectrum[-2] = 10 # Negative frequency (conjugate)\n", - " spectrum[3] = 7.5 # Second frequency\n", - " spectrum[-3] = 7.5 # Its conjugate\n", - " spectrum[4] = 5 # Second frequency\n", - " spectrum[-4] = 5 # Its conjugate\n", - " spectrum[5] = 2.5 # Third frequency \n", - " spectrum[-5] = 2.5 # Its conjugate\n", - " \n", - " # Generate signal from spectrum\n", - " template = np.fft.ifft(spectrum).real\n", - "\n", - " return template\n", - "\n", - "def ModularAdditionDataset(p, template):\n", - " # Initialize data arrays\n", - " X = np.zeros((p * p, 2, p)) # Shape: (p^2, 2, p)\n", - " Y = np.zeros((p * p, p)) # Shape: (p^2, p)\n", - " \n", - " # Generate the dataset\n", - " idx = 0\n", - " for a in range(p):\n", - " for b in range(p):\n", - " q = (a + b) % p # a + b mod p\n", - " X[idx, 0, :] = np.roll(template, a)\n", - " X[idx, 1, :] = np.roll(template, b)\n", - " Y[idx, :] = np.roll(template, q)\n", - " idx += 1\n", - " \n", - " return X, Y" - ] - }, - { - "cell_type": "markdown", - "id": "7a0ecbbd-ceaf-4bef-af4a-13a22fa70063", - "metadata": {}, - "source": [ - "## Architecture" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2cf22b7d-49e7-445b-8742-2e75cd1fa55a", - "metadata": {}, - "outputs": [], - "source": [ - "class TwoLayerNet(nn.Module):\n", - " def __init__(self, p, hidden_size, nonlinearity='square', init_scale=1.0, output_scale=1.0):\n", - " super(TwoLayerNet, self).__init__()\n", - " \n", - " # Store dimensions\n", - " self.p = p\n", - " self.hidden_size = hidden_size\n", - " self.nonlinearity = nonlinearity\n", - " self.init_scale = init_scale\n", - " self.output_scale = output_scale\n", - " \n", - " # Initialize parameters \n", - " self.U = nn.Parameter(self.init_scale * torch.randn(hidden_size, p) / np.sqrt(2 * p)) # First p elements\n", - " self.V = nn.Parameter(self.init_scale * torch.randn(hidden_size, p) / np.sqrt(2 * p)) # Second p elements\n", - " self.W = nn.Parameter(self.init_scale * torch.randn(hidden_size, p) / 
np.sqrt(p)) # Second layer weights\n", - "\n", - " def forward(self, x):\n", - " \n", - " # First layer (linear and combined)\n", - " x1 = x[:, :self.p] @ self.U.T\n", - " x2 = x[:, self.p:] @ self.V.T\n", - " x_combined = x1 + x2\n", - "\n", - " # Apply nonlinearity activation\n", - " if self.nonlinearity == 'relu':\n", - " x_combined = torch.relu(x_combined)\n", - " elif self.nonlinearity == 'square':\n", - " x_combined = x_combined**2\n", - " elif self.nonlinearity == 'linear':\n", - " x_combined = x_combined\n", - " elif self.nonlinearity == 'tanh':\n", - " x_combined = torch.tanh(x_combined)\n", - " elif self.nonlinearity == 'gelu':\n", - " gelu = torch.nn.GELU()\n", - " x_combined = gelu(x_combined)\n", - " else:\n", - " raise ValueError(f\"Invalid nonlinearity '{self.nonlinearity}' provided.\")\n", - "\n", - " # Second layer (linear)\n", - " x_out = x_combined @ self.W\n", - "\n", - " # Feature learning scaling\n", - " x_out *= self.output_scale\n", - " \n", - " return x_out" - ] - }, - { - "cell_type": "markdown", - "id": "f7e7336b-5c6e-48af-a357-2b2c877f6168", - "metadata": {}, - "source": [ - "## Optimization" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bc9ba87b-9607-4a4a-9b00-00c15adb2f5a", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from torch import nn\n", - "\n", - "class PerNeuronScaledSGD(torch.optim.Optimizer):\n", - " \"\"\"\n", - " Per-neuron scaled SGD optimizer that exploits model homogeneity.\n", - " \n", - " Learning rate scaling per neuron i:\n", - " eta_i = lr * ||theta_i||^(2-degree)\n", - " \n", - " where:\n", - " - theta_i comprises all parameters associated with neuron i\n", - " - degree is the degree of homogeneity of the model\n", - "\n", - " For SequentialMLP with sequence length k:\n", - " - theta_i = (W_in[i, :], W_out[:, i])\n", - " - degree = k+1 (activation is x^k, one more layer for W_out)\n", - "\n", - " For TwoLayerNet:\n", - " - theta_i = (U[i, :], V[i, :], W[i, :])\n", - " - degree:\n", - " * nonlinearity == 'square' -> 3\n", - " * otherwise default 2 (unless explicitly provided)\n", - " \"\"\"\n", - "\n", - " def __init__(self, \n", - " model, \n", - " lr=1.0, \n", - " degree=None\n", - " ) -> None:\n", - " \"\"\"\n", - " Args:\n", - " model: SequentialMLP, TwoLayerNet, or compatible model\n", - " lr: base learning rate\n", - " degree: degree of homogeneity (exponent for norm-based scaling)\n", - " If None, inferred from model:\n", - " - SequentialMLP: uses k+1 where k is sequence length\n", - " - TwoLayerNet:\n", - " * 'square' -> 3\n", - " * otherwise -> 2\n", - " - Default: 2\n", - " \"\"\"\n", - " model_type = type(model).__name__\n", - "\n", - " # Infer degree of homogeneity from model if not provided\n", - " if degree is None:\n", - " if hasattr(model, 'k'): # SequentialMLP-style\n", - " degree = model.k + 1\n", - " elif model_type == 'TwoLayerNet':\n", - " nl = getattr(model, 'nonlinearity', None)\n", - " if nl == 'square':\n", - " degree = 3\n", - " else:\n", - " # For relu/linear/tanh/gelu or unknown, fall back\n", - " degree = 2\n", - " else:\n", - " # Default for quadratic-ish models\n", - " degree = 2\n", - " \n", - " # Get model parameters\n", - " params = list(model.parameters())\n", - " \n", - " super().__init__(\n", - " [{'params': params, 'model': model, 'model_type': model_type}], \n", - " dict(lr=lr, degree=degree)\n", - " )\n", - "\n", - " @torch.no_grad()\n", - " def step(self, closure=None):\n", - " group = self.param_groups[0]\n", - " model = group['model']\n", - " lr = 
group['lr']\n", - " degree = group['degree']\n", - " model_type = group['model_type']\n", - " \n", - " if model_type == 'SequentialMLP':\n", - " # SequentialMLP: W_in (d, k*p), W_out (p, d)\n", - " W_in = model.W_in\n", - " W_out = model.W_out\n", - " g_in = W_in.grad\n", - " g_out = W_out.grad\n", - " \n", - " if g_in is None or g_out is None:\n", - " return\n", - " \n", - " # Per-neuron norms: theta_i = (W_in[i, :], W_out[:, i])\n", - " u2 = (W_in**2).sum(dim=1) # (d,)\n", - " w2 = (W_out**2).sum(dim=0) # (d,)\n", - " theta_norm = torch.sqrt(u2 + w2 + 1e-12) # (d,)\n", - " \n", - " # Scale = ||theta_i||^(2-degree)\n", - " scale = theta_norm.pow(2 - degree)\n", - " \n", - " # Scale each neuron's gradients\n", - " g_in.mul_(scale.view(-1, 1))\n", - " g_out.mul_(scale.view(1, -1))\n", - " \n", - " # SGD update\n", - " W_in.add_(g_in, alpha=-lr)\n", - " W_out.add_(g_out, alpha=-lr)\n", - "\n", - " elif model_type == 'TwoLayerNet':\n", - " # TwoLayerNet: U (d, p), V (d, p), W (d, p)\n", - " U = model.U\n", - " V = model.V\n", - " W = model.W\n", - "\n", - " g_U = U.grad\n", - " g_V = V.grad\n", - " g_W = W.grad\n", - "\n", - " if g_U is None or g_V is None or g_W is None:\n", - " return\n", - "\n", - " # Per-neuron norms: theta_i = (U[i, :], V[i, :], W[i, :])\n", - " u2 = (U**2).sum(dim=1) # (hidden_size,)\n", - " v2 = (V**2).sum(dim=1) # (hidden_size,)\n", - " w2 = (W**2).sum(dim=1) # (hidden_size,)\n", - " theta_norm = torch.sqrt(u2 + v2 + w2 + 1e-12) # (hidden_size,)\n", - "\n", - " # Scale = ||theta_i||^(2-degree)\n", - " scale = theta_norm.pow(2 - degree) # (hidden_size,)\n", - "\n", - " # Scale each neuron's gradients\n", - " scale_view = scale.view(-1, 1) # (hidden_size, 1)\n", - " g_U.mul_(scale_view)\n", - " g_V.mul_(scale_view)\n", - " g_W.mul_(scale_view)\n", - "\n", - " # SGD update\n", - " U.add_(g_U, alpha=-lr)\n", - " V.add_(g_V, alpha=-lr)\n", - " W.add_(g_W, alpha=-lr)\n", - "\n", - " else:\n", - " raise ValueError(f\"PerNeuronScaledSGD: Unsupported model structure with {model_type}\")\n", - " \n", - " return None\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1035f81c-e877-4655-8640-4e4c3d323af8", - "metadata": {}, - "outputs": [], - "source": [ - "def test_accuracy(model, dataloader):\n", - " correct = 0\n", - " total = 0\n", - " \n", - " with torch.no_grad(): # Disable gradient calculation for evaluation\n", - " for inputs, labels in dataloader:\n", - " inputs = inputs.view(inputs.shape[0], -1) # Flatten input for FC layers\n", - " outputs = model(inputs)\n", - " _, predicted = torch.max(outputs, 1) # Get the index of the largest value (class)\n", - " _, true_labels = torch.max(labels, 1) # Get the true class from the one-hot encoding\n", - " correct += (predicted == true_labels).sum().item()\n", - " total += labels.size(0)\n", - " \n", - " accuracy = 100 * correct / total\n", - " return accuracy\n", - "\n", - "def train(model, dataloader, criterion, optimizer, epochs=100, verbose_interval=10):\n", - " model.train() # Set the model to training mode\n", - " loss_history = [] # List to store loss values\n", - " accuracy_history = []\n", - " param_history = []\n", - "\n", - " for epoch in range(epochs):\n", - " running_loss = 0.0\n", - " for inputs, labels in dataloader:\n", - " inputs = inputs.view(inputs.shape[0], -1) # Flatten input for FC layers\n", - "\n", - " optimizer.zero_grad() # Zero gradients\n", - " outputs = model(inputs) # Forward pass\n", - " loss = criterion(outputs, labels) # Compute loss\n", - " loss.backward() # 
Backpropagation\n", - " optimizer.step() # Update weights\n", - "\n", - " running_loss += loss.item()\n", - "\n", - " # Append the average loss for the epoch to loss_history\n", - " avg_loss = running_loss / len(dataloader)\n", - " loss_history.append(avg_loss)\n", - "\n", - " # Append the accuracy\n", - " model.eval()\n", - " accuracy = test_accuracy(model, dataloader)\n", - " accuracy_history.append(accuracy)\n", - " model.train()\n", - "\n", - " # Save current model parameters\n", - " current_params = {\n", - " \"U\": model.U.detach().cpu().clone(),\n", - " \"V\": model.V.detach().cpu().clone(),\n", - " \"W\": model.W.detach().cpu().clone()\n", - " }\n", - " param_history.append(current_params)\n", - "\n", - " # Print verbose information every `verbose_interval` epochs\n", - " if (epoch + 1) % verbose_interval == 0:\n", - " print(f\"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%\")\n", - "\n", - " return loss_history, accuracy_history, param_history # Return loss history for plotting" - ] - }, - { - "cell_type": "markdown", - "id": "0e86c4f6-83a6-4465-abf0-7d104432cc9c", - "metadata": {}, - "source": [ - "## Plotting functions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "014e2d10-9550-4fd4-adb7-168a27fda1b3", - "metadata": {}, - "outputs": [], - "source": [ - "def style_axes(ax, numyticks=5, numxticks=5, labelsize=24):\n", - " # Y-axis ticks\n", - " ax.tick_params(axis=\"y\", which=\"both\", bottom=True, top=False,\n", - " labelbottom=True, left=True, right=False,\n", - " labelleft=True, direction='out', length=7, width=1.5, pad=8, labelsize=labelsize)\n", - " ax.yaxis.set_major_locator(MaxNLocator(nbins=numyticks))\n", - " \n", - " # X-axis ticks\n", - " ax.tick_params(axis=\"x\", which=\"both\", bottom=True, top=False,\n", - " labelbottom=True, left=True, right=False,\n", - " labelleft=True, direction='out', length=7, width=1.5, pad=8, labelsize=labelsize)\n", - " ax.xaxis.set_major_locator(MaxNLocator(nbins=numxticks))\n", - "\n", - " # Scientific notation formatting\n", - " if ax.get_yscale() == 'linear':\n", - " ax.ticklabel_format(style='sci', axis='y', scilimits=(-2, 2))\n", - " if ax.get_xscale() == 'linear':\n", - " ax.ticklabel_format(style='sci', axis='x', scilimits=(-2, 2))\n", - "\n", - " ax.xaxis.offsetText.set_fontsize(20)\n", - " ax.grid()\n", - "\n", - " # Customize spines\n", - " for spine in [\"top\", \"right\"]:\n", - " ax.spines[spine].set_visible(False)\n", - " for spine in [\"left\", \"bottom\"]:\n", - " ax.spines[spine].set_linewidth(3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "20989d96-f34f-4be7-a0f9-4b92fb7f235a", - "metadata": {}, - "outputs": [], - "source": [ - "def get_power(points):\n", - " p = len(points)\n", - " num_coefficients = (p // 2) + 1\n", - " \n", - " # Perform FFT and calculate power spectrum\n", - " ft = np.fft.fft(points) # Could consider using np.fft.rfft which is designed for real valued input.\n", - " power = np.abs(ft[:num_coefficients])**2 / p\n", - " \n", - " # Double power for frequencies strictly between 0 and Nyquist (Nyquist is not doubled if p is even)\n", - " if p % 2 == 0: # p is even, Nyquist frequency at index num_coefficients - 1\n", - " power[1:num_coefficients - 1] *= 2\n", - " else: # p is odd, no Nyquist frequency\n", - " power[1:] *= 2\n", - "\n", - " # Confirm the power sum approximates the squared norm of points\n", - " total_power = np.sum(power)\n", - " norm_squared = np.linalg.norm(points)**2\n", - " if not np.isclose(total_power, 
norm_squared, rtol=1e-3):\n", - "        print(f\"Warning: Total power {total_power:.3f} does not match norm squared {norm_squared:.3f}\")\n", - "\n", - "    return np.arange(num_coefficients), power\n", - "\n", - "def interpolate(ax, points, color, continuous, alpha=1.0):\n", - "    p = len(points)\n", - "    if continuous:\n", - "        # Perform Fourier Transform\n", - "        ft = np.fft.fft(points)\n", - "        \n", - "        # Keep only non-negative frequencies (first half + Nyquist if p is even)\n", - "        num_coefficients = (p // 2) + 1\n", - "        ft = ft[:num_coefficients]  # Truncate to keep non-negative frequencies\n", - "        \n", - "        # Create a dense set of x-values for smooth interpolation\n", - "        xs = np.linspace(0, p, 10 * p)  # 10 times more points than the original for smoothness\n", - "        curr_val = np.zeros(xs.shape, dtype=complex)\n", - "        \n", - "        # Use only non-negative frequencies for interpolation\n", - "        for freq in range(num_coefficients):\n", - "            theta = np.angle(ft[freq])\n", - "            r = np.abs(ft[freq]) / p\n", - "            # Double amplitude except for DC (freq = 0) and Nyquist (freq = p / 2, when p is even)\n", - "            if freq > 0 and (freq < p / 2 or p % 2 != 0):\n", - "                r *= 2\n", - "            curr_val += r * np.exp(1j * ((2 * np.pi * freq * xs / p) + theta))\n", - "\n", - "        # Plot the real part (since output is real-valued)\n", - "        ax.plot(xs, curr_val.real, color=color, alpha=alpha)\n", - "    else:\n", - "        ax.plot(np.arange(p), points, color=color, alpha=alpha) " - ] - }, - { - "cell_type": "markdown", - "id": "e99dae27-f8fe-403a-b70f-0bcaf818cbe7", - "metadata": {}, - "source": [ - "## Gradient Descent Experiment" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bcd15c5a-5745-41ba-b015-48e403160c7e", - "metadata": {}, - "outputs": [], - "source": [ - "seed = 0  # or any integer you like\n", - "random.seed(seed)\n", - "np.random.seed(seed)\n", - "torch.manual_seed(seed)\n", - "torch.cuda.manual_seed_all(seed)  # if using GPU\n", - "\n", - "# TEST_MODE: Reduce hidden_size for faster automated testing\n", - "import os\n", - "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n", - "\n", - "# Data Generation\n", - "p = 11  # Modulus (kept the same in TEST_MODE to avoid index errors)\n", - "\n", - "# Get base vector\n", - "# template = one_hot(p)\n", - "template = generate_fixed_template(p)\n", - "\n", - "# Mean center template\n", - "template -= np.mean(template)\n", - "\n", - "# Generate dataset using numpy\n", - "X, Y = ModularAdditionDataset(p, template)\n", - "\n", - "# Convert to PyTorch tensors\n", - "X_tensor = torch.tensor(X, dtype=torch.float32).view(-1, 2 * p)  # Flatten input (num_samples, 2*p)\n", - "Y_tensor = torch.tensor(Y, dtype=torch.float32)  # Targets (num_samples, p)\n", - "\n", - "# Create a TensorDataset and DataLoader\n", - "dataset = TensorDataset(X_tensor, Y_tensor)\n", - "dataloader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)\n", - "# dataloader = DataLoader(dataset, batch_size=32, shuffle=False)\n", - "\n", - "# Initialize model\n", - "hidden_size = 20 if TEST_MODE else 200  # Reduced in test mode\n", - "model = TwoLayerNet(p=p, hidden_size=hidden_size, nonlinearity='square', init_scale=1e-5, output_scale=1e0)\n", - "\n", - "# Create loss function\n", - "loss = nn.MSELoss()\n", - "\n", - "# Construct optimizer\n", - "lr = 0.01\n", - "optimizer = PerNeuronScaledSGD(model, lr=lr, degree=3)  # explicit\n", - "\n", - "# Train the model\n", - "
os\n", - "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n", - "epochs = 2 if TEST_MODE else 50001\n", - "loss_history, accuracy_history, param_history = train(model, dataloader, loss, optimizer, epochs=epochs, verbose_interval=max(1, epochs//10))" - ] - }, - { - "cell_type": "markdown", - "id": "0f48aebc-a439-405a-a057-3f5c24cca91a", - "metadata": {}, - "source": [ - "## Plot Loss" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ff46febe-abb5-459a-bb06-a18a26afb967", - "metadata": {}, - "outputs": [], - "source": [ - "fig, ax = plt.subplots(1, 1, figsize=(8, 6))\n", - "ax.plot(list(loss_history), lw=7)\n", - "\n", - "# === Compute power spectrum of template ===\n", - "freq, power = get_power(template)\n", - "valid = power > 1e-20\n", - "freq, power = freq[valid], power[valid]\n", - "sorted_idx = np.argsort(-power)\n", - "freq, power = freq[sorted_idx], power[sorted_idx]\n", - "\n", - "alpha_values = [np.sum(power[k:]) for k in range(len(power))]\n", - "coef = 1 / p\n", - "for k, alpha in enumerate(alpha_values):\n", - " ax.axhline(y=coef * alpha, color='black', linestyle='--', linewidth=2, zorder=-2)\n", - "\n", - "ax.set_xscale(\"log\")\n", - "ax.set_yscale(\"log\")\n", - "ax.set_ylim(1e-2, 10)\n", - "ax.set_xlabel('Epochs', fontsize=24)\n", - "ax.set_ylabel('Train Loss', fontsize=24)\n", - "\n", - "style_axes(ax)\n", - "plt.grid(False)\n", - "plt.tight_layout()\n", - "plt.savefig(\"loss-without-lines.pdf\", bbox_inches=\"tight\")\n", - "plt.savefig(\"loss-without-lines.svg\", bbox_inches=\"tight\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "40b851e7-6256-43cd-b9f3-aca38db04917", - "metadata": {}, - "source": [ - "## Power Spectrum of output" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "68b25ca9-6339-49dd-9d45-577a51798a25", - "metadata": {}, - "outputs": [], - "source": [ - "# === SETTINGS ===\n", - "p = Y_tensor.shape[1]\n", - "num_freqs = p // 2 + 1\n", - "mom = 0.9\n", - "\n", - "# Compute template power spectrum\n", - "template_ft = np.fft.rfft(template)\n", - "template_power = np.abs(template_ft)[:num_freqs]\n", - "\n", - "# === Compute power spectrum of template ===\n", - "freq, power = get_power(template)\n", - "valid = power > 1e-20\n", - "freq, power = freq[valid], power[valid]\n", - "sorted_idx = np.argsort(-power)\n", - "freq, power = freq[sorted_idx], power[sorted_idx]\n", - "\n", - "# === Theory lines ===\n", - "alpha_values = [np.sum(power[k:]) for k in range(len(power))]\n", - "coef = 1 / p\n", - "theta0 = np.sqrt(2) * model.init_scale\n", - "uMax = [np.sqrt(2 * p / 27) * (p * power[k] / 2)**(3/2) / p**2 for k in range(len(power))]\n", - "tau_values = [(1 / theta0 - 1) / (3 * uMax[k]) for k in range(len(uMax))]\n", - "step_size = 2 * coef * lr / (1 - mom)\n", - "\n", - "\n", - "# Color settings\n", - "cmap = plt.colormaps.get_cmap('tab20').resampled(num_freqs)\n", - "manual_colors = {\n", - " 0: 'tab:blue',\n", - " 1: 'tab:orange',\n", - " 2: 'tab:red',\n", - " 3: 'tab:green',\n", - " 4: 'tab:brown',\n", - " 5: 'tab:purple',\n", - "}\n", - "colors = [manual_colors.get(i, cmap(i)) for i in range(num_freqs)]\n", - "\n", - "# Compute output power over time (GD)\n", - "num_points = 1000\n", - "steps = np.unique(np.logspace(0, np.log10(len(param_history) - 1), num_points, dtype=int))\n", - "powers_over_time = []\n", - "\n", - "for step in steps:\n", - " model.load_state_dict(param_history[step])\n", - " model.eval()\n", - " with torch.no_grad():\n", - " outputs = model(X_tensor)\n", 
- " ft = np.fft.rfft(outputs.detach().cpu().numpy(), axis=1)\n", - " avg_power = np.mean(np.abs(ft), axis=0)\n", - " powers_over_time.append(avg_power)\n", - "\n", - "powers_over_time = np.array(powers_over_time) # shape: (steps, freqs)\n", - "\n", - "# === PLOTTING ===\n", - "fig, ax = plt.subplots(figsize=(8, 6))\n", - "\n", - "for k in range(num_freqs):\n", - " color = colors[k]\n", - " label = fr\"$\\xi = {k}$\"# if k in [1, 3, 5] else None\n", - " ax.plot(steps, powers_over_time[:, k], color=color, lw=5, label=label)\n", - " label_agf = 'AGF' if k == 10 else None\n", - " ax.axhline(template_power[k], color=color, linestyle='dotted', linewidth=2, alpha=0.5, zorder=-10)\n", - "\n", - "\n", - "# Labeling and formatting\n", - "ax.set_xscale('log')\n", - "ax.set_ylabel(\"Power\", fontsize=24)\n", - "ax.set_xlabel(\"Epochs\", fontsize=24)\n", - "ax.legend(fontsize=14, title=\"Frequency\", title_fontsize=16, loc='upper left', labelspacing=0.25)\n", - "style_axes(ax)\n", - "ax.grid(False)\n", - "plt.tight_layout()\n", - "plt.savefig(\"fourier_power_only.pdf\", bbox_inches=\"tight\")\n", - "plt.savefig(\"fourier_power_only.svg\", bbox_inches=\"tight\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "5ef2c971-d9f1-41e6-b8eb-4e467496ccfd", - "metadata": {}, - "source": [ - "## Plot outputs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e333d1ab-1501-434f-86d2-82c10bb58f11", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "import torch\n", - "\n", - "# -------------------------------------------------------------------\n", - "# Choose the template (example: first row of Y_tensor)\n", - "# -------------------------------------------------------------------\n", - "template = Y_tensor[0].detach().cpu().numpy() # shape (p,)\n", - "p = template.shape[0]\n", - "x = np.arange(p)\n", - "\n", - "# -------------------------------------------------------------------\n", - "# Figure 1: Template as black bar plot\n", - "# -------------------------------------------------------------------\n", - "fig1, ax1 = plt.subplots(figsize=(8, 6))\n", - "ax1.bar(x, template, color=\"black\")\n", - "\n", - "ax1.set_xlabel(\"Index\", fontsize=14)\n", - "ax1.set_ylabel(\"Template value\", fontsize=14)\n", - "style_axes(ax1)\n", - "ax1.grid(False)\n", - "\n", - "plt.tight_layout()\n", - "fig1.savefig(\"template_bar.pdf\", bbox_inches=\"tight\")\n", - "\n", - "# -------------------------------------------------------------------\n", - "# Figure 2: Fourier magnitude bar plot with conjugate-pair coloring\n", - "# -------------------------------------------------------------------\n", - "# Compute Fourier transform and magnitude\n", - "fft_template = np.fft.fft(template)\n", - "fft_mag = np.abs(fft_template)\n", - "freqs = np.arange(p)\n", - "\n", - "# Number of *frequency groups* accounting for conjugate symmetry:\n", - "# groups: 0, 1, ..., floor(p/2)\n", - "num_groups = p // 2 + 1\n", - "\n", - "# Color settings\n", - "cmap = plt.colormaps.get_cmap('tab20').resampled(num_groups)\n", - "manual_colors = {\n", - " 0: 'tab:blue',\n", - " 1: 'tab:orange',\n", - " 2: 'tab:red',\n", - " 3: 'tab:green',\n", - " 4: 'tab:brown',\n", - " 5: 'tab:purple',\n", - "}\n", - "group_colors = [manual_colors.get(i, cmap(i)) for i in range(num_groups)]\n", - "\n", - "# Assign each k a color based on its conjugate-symmetry group\n", - "bar_colors = []\n", - "for k in range(p):\n", - " # group index: k and p-k 
share the same group\n", - "    g = k if k <= p // 2 else p - k\n", - "    bar_colors.append(group_colors[g])\n", - "\n", - "fig2, ax2 = plt.subplots(figsize=(8, 6))\n", - "ax2.bar(freqs, fft_mag, color=bar_colors)\n", - "\n", - "ax2.set_xlabel(\"Frequency index $k$\", fontsize=14)\n", - "ax2.set_ylabel(r\"$|\\hat{t}[k]|$\", fontsize=14)\n", - "style_axes(ax2)\n", - "ax2.grid(False)\n", - "\n", - "plt.tight_layout()\n", - "fig2.savefig(\"template_fft_bar.pdf\", bbox_inches=\"tight\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4ecb1adb-66ed-44a1-9b6a-ef5dcc6dbe11", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "group-agf", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/cn.ipynb b/notebooks/cn.ipynb new file mode 100644 index 0000000..fa3bf65 --- /dev/null +++ b/notebooks/cn.ipynb @@ -0,0 +1,538 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Binary Group Composition on $C_n$\n", + "\n", + "**Group:** Cyclic group $C_p$ of order $p$ (i.e., modular addition mod $p$). \n", + "**Task:** Given encodings of two group elements $a, b \in C_p$, predict the encoding of their product $a + b \pmod{p}$. \n", + "**Sequence length:** $k = 2$ (binary composition). \n", + "**Architecture:** `TwoLayerNet` with square nonlinearity. \n", + "**Key result:** The network learns one Fourier mode at a time, producing a staircase in the training loss. \n",
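+ "\n", + "For example, with $p = 11$ the input pair $(a, b) = (7, 8)$ has target $(7 + 8) \bmod 11 = 4$: the network must map the encodings of $7$ and $8$ to the encoding of $4$, i.e., the template cyclically shifted by four positions."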
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import random\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import DataLoader, TensorDataset\n", + "\n", + "import src.dataset as dataset\n", + "import src.model as model\n", + "import src.optimizer as optimizer\n", + "import src.power as power\n", + "import src.template as template\n", + "import src.train as train_mod\n", + "import src.viz as viz" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n", + "\n", + "seed = 0\n", + "random.seed(seed)\n", + "np.random.seed(seed)\n", + "torch.manual_seed(seed)\n", + "\n", + "p = 11\n", + "hidden_size = 20 if TEST_MODE else 200\n", + "epochs = 2 if TEST_MODE else 5000\n", + "lr = 0.01\n", + "init_scale = 1e-5\n", + "\n", + "FIGURES_DIR = \"figures\"\n", + "os.makedirs(FIGURES_DIR, exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Template and Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "067e5bc3", + "metadata": {}, + "outputs": [], + "source": [ + "# Build a fixed Cn template with known Fourier structure\n", + "tpl = template.fixed_cn(\n", + " group_size=p,\n", + " fourier_coef_mags=[0, 12.5, 10, 7.5, 5, 2.5],\n", + ")\n", + "\n", + "# Mean-center the template\n", + "tpl = tpl - np.mean(tpl)\n", + "\n", + "# Build exhaustive dataset: all p^2 pairs\n", + "X, Y = dataset.cn_dataset(tpl)\n", + "\n", + "# Move to tensors and flatten\n", + "X_tensor, Y_tensor, device = dataset.move_dataset_to_device_and_flatten(X, Y)\n", + "\n", + "ds = TensorDataset(X_tensor, Y_tensor)\n", + "dataloader = DataLoader(ds, batch_size=len(ds), shuffle=False)\n", + "\n", + "print(f\"Group: C_{p}, order {p}\")\n", + "print(f\"Dataset: {len(ds)} samples (all {p}x{p} pairs)\")\n", + "print(f\"X shape: {X_tensor.shape}, Y shape: {Y_tensor.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a26b21e6", + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize template and its power spectrum\n", + "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n", + "\n", + "ax1.bar(range(p), tpl, color=\"black\")\n", + "ax1.set_xlabel(\"Group element\")\n", + "ax1.set_ylabel(\"Template value\")\n", + "ax1.set_title(f\"Template $t$ on $C_{{{p}}}$\")\n", + "\n", + "pwr, freqs = power.get_power_1d(tpl)\n", + "ax2.bar(freqs, pwr, color=\"steelblue\")\n", + "ax2.set_xlabel(\"Frequency\")\n", + "ax2.set_ylabel(\"Power\")\n", + "ax2.set_title(\"Power spectrum of template\")\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(f\"{FIGURES_DIR}/cn_template.pdf\", bbox_inches=\"tight\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model and Optimizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "230b324e", + "metadata": {}, + "outputs": [], + "source": [ + "net = model.TwoLayerNet(\n", + " group_size=p,\n", + " hidden_size=hidden_size,\n", + " nonlinearity=\"square\",\n", + " init_scale=init_scale,\n", + ")\n", + "net = net.to(device)\n", + "\n", + "criterion = 
nn.MSELoss()\n", + "opt = optimizer.PerNeuronScaledSGD(net, lr=lr, degree=3)\n", + "\n", + "print(f\"Model: TwoLayerNet(p={p}, hidden={hidden_size}, init_scale={init_scale})\")\n", + "print(f\"Optimizer: PerNeuronScaledSGD(lr={lr}, degree=3)\")\n", + "print(f\"Training for {epochs} epochs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loss_history, val_loss_history, param_history, param_save_epochs, final_epoch = train_mod.train(\n", + " net,\n", + " dataloader,\n", + " criterion,\n", + " opt,\n", + " epochs=epochs,\n", + " verbose_interval=max(1, epochs // 10),\n", + " save_param_interval=1,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training Loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Compute theoretical loss plateau levels\n", + "theory = power.theoretical_loss_levels_1d(tpl)\n", + "\n", + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "ax.plot(loss_history, lw=4)\n", + "\n", + "for level in theory[\"levels\"]:\n", + " ax.axhline(y=level, color=\"black\", linestyle=\"--\", linewidth=2, zorder=-2)\n", + "\n", + "ax.set_xscale(\"log\")\n", + "ax.set_yscale(\"log\")\n", + "ax.set_xlabel(\"Epochs\", fontsize=18)\n", + "ax.set_ylabel(\"Train Loss\", fontsize=18)\n", + "ax.set_title(f\"Training loss on $C_{{{p}}}$\", fontsize=20)\n", + "viz.style_axes(ax)\n", + "ax.grid(False)\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(f\"{FIGURES_DIR}/cn_loss.pdf\", bbox_inches=\"tight\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Power Spectrum Over Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Compute template power for reference lines\n", + "num_freqs = p // 2 + 1\n", + "template_ft = np.fft.rfft(tpl)\n", + "template_power = np.abs(template_ft)[:num_freqs]\n", + "\n", + "# Compute output power over time\n", + "num_points = min(500, len(param_history))\n", + "steps = np.unique(np.logspace(0, np.log10(max(1, len(param_history) - 1)), num_points, dtype=int))\n", + "powers_over_time = []\n", + "\n", + "for step in steps:\n", + " net.load_state_dict(param_history[step])\n", + " net.eval()\n", + " with torch.no_grad():\n", + " outputs = net(X_tensor)\n", + " ft = np.fft.rfft(outputs.detach().cpu().numpy(), axis=1)\n", + " avg_power = np.mean(np.abs(ft), axis=0)\n", + " powers_over_time.append(avg_power)\n", + "\n", + "powers_over_time = np.array(powers_over_time)\n", + "\n", + "# Plot\n", + "colors = [\"tab:blue\", \"tab:orange\", \"tab:red\", \"tab:green\", \"tab:brown\", \"tab:purple\"]\n", + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "\n", + "for k in range(num_freqs):\n", + " color = colors[k] if k < len(colors) else f\"C{k}\"\n", + " ax.plot(steps, powers_over_time[:, k], color=color, lw=4, label=rf\"$\\xi = {k}$\")\n", + " ax.axhline(template_power[k], color=color, linestyle=\"dotted\", linewidth=2, alpha=0.5, zorder=-10)\n", + "\n", + "ax.set_xscale(\"log\")\n", + "ax.set_ylabel(\"Power\", fontsize=18)\n", + "ax.set_xlabel(\"Epochs\", fontsize=18)\n", + "ax.set_title(f\"Power spectrum over training on $C_{{{p}}}$\", fontsize=20)\n", + "ax.legend(fontsize=12, title=\"Frequency\", title_fontsize=14, loc=\"upper left\", labelspacing=0.25)\n", + "viz.style_axes(ax)\n", + 
"ax.grid(False)\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(f\"{FIGURES_DIR}/cn_power_spectrum.pdf\", bbox_inches=\"tight\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## AGF Numerics\n", + "\n", + "Compare gradient descent training with the Alternating Gradient Flow (AGF) approximation.\n", + "AGF decomposes training into alternating utility-maximization and cost-minimization phases,\n", + "predicting when each Fourier mode activates." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm import tqdm\n", + "\n", + "\n", + "class ModsumSubNetwork(nn.Module):\n", + " \"\"\"A single neuron of the two-layer network for AGF analysis.\"\"\"\n", + "\n", + " def __init__(self, d_in, d_out, init_scale):\n", + " super().__init__()\n", + " assert d_in % 2 == 0\n", + " self.p = d_in // 2\n", + " self.u = nn.Linear(self.p, 1, bias=False)\n", + " self.v = nn.Linear(self.p, 1, bias=False)\n", + " self.w = nn.Linear(1, d_out, bias=False)\n", + " with torch.no_grad():\n", + " self.w.weight.mul_(init_scale)\n", + " self.u.weight.mul_(init_scale)\n", + " self.v.weight.mul_(init_scale)\n", + " self.active = False\n", + " self.util_acc = 0\n", + " self.c_a = 1 / self.get_norm() - 1\n", + " self.normalize()\n", + "\n", + " def get_norm(self):\n", + " sqnorm = lambda x: torch.linalg.norm(x.weight) ** 2\n", + " return torch.sqrt(sqnorm(self.w) + sqnorm(self.u) + sqnorm(self.v))\n", + "\n", + " def reinitialize(self, u, v, w):\n", + " with torch.no_grad():\n", + " self.u.weight.copy_(u)\n", + " self.v.weight.copy_(v)\n", + " self.w.weight.copy_(w)\n", + " self.c_a = 1 / self.get_norm() - 1\n", + "\n", + " def forward(self, x):\n", + " x1 = x[:, : self.p]\n", + " x2 = x[:, self.p :]\n", + " return self.w((self.u(x1) + self.v(x2)) ** 2)\n", + "\n", + " def normalize(self):\n", + " norm = self.get_norm()\n", + " with torch.no_grad():\n", + " self.w.weight.div_(norm)\n", + " self.u.weight.div_(norm)\n", + " self.v.weight.div_(norm)\n", + "\n", + " def utility_step(self, x, residual, learning_rate):\n", + " f_i = self(x)\n", + " util = torch.einsum(\"nd,nd->n\", f_i, residual).mean()\n", + " self.util_acc += 3 * learning_rate * util.item()\n", + " norm_th = 1 / (1 + self.c_a - self.util_acc)\n", + " util.backward()\n", + " with torch.no_grad():\n", + " self.w.weight += norm_th * learning_rate * self.w.weight.grad\n", + " self.u.weight += norm_th * learning_rate * self.u.weight.grad\n", + " self.v.weight += norm_th * learning_rate * self.v.weight.grad\n", + " self.w.weight.grad.zero_()\n", + " self.u.weight.grad.zero_()\n", + " self.v.weight.grad.zero_()\n", + " self.normalize()\n", + "\n", + "\n", + "class ModsumNetwork(nn.Module):\n", + " \"\"\"Network of ModsumSubNetwork neurons for AGF simulation.\"\"\"\n", + "\n", + " def __init__(self, d_in, d_out, init_scale, width=100):\n", + " super().__init__()\n", + " self.d_in = d_in\n", + " self.d_out = d_out\n", + " self.width = width\n", + " neurons = [ModsumSubNetwork(d_in, d_out, init_scale) for _ in range(width)]\n", + " self.neurons = nn.ModuleList(neurons)\n", + " self.set_mode(\"utilmax\")\n", + "\n", + " def load_init(self, U, V, W):\n", + " for i, n in enumerate(self.neurons):\n", + " u, v, w = U[i], V[i], W[i][:, None]\n", + " n.reinitialize(u, v, w)\n", + "\n", + " def dormant(self):\n", + " return [neuron for neuron in self.neurons if not neuron.active]\n", + "\n", + " def set_mode(self, mode):\n", + " if mode not in [\"utilmax\", 
\"costmin\"]:\n", + " raise ValueError(\"mode must be utilmax or costmin\")\n", + " self.mode = mode\n", + " for neuron in self.neurons:\n", + " grad_on = (mode == \"utilmax\") ^ neuron.active\n", + " for param in neuron.parameters():\n", + " param.requires_grad = grad_on\n", + "\n", + " def forward(self, x):\n", + " active = [n for n in self.neurons if n.active]\n", + " if not active:\n", + " return torch.zeros(x.shape[0], self.d_out)\n", + " outputs = torch.stack([n(x) for n in active], dim=0)\n", + " return torch.sum(outputs, dim=0)\n", + "\n", + "\n", + "def train_agf(\n", + " X_train, Y_train, init_sz=1e-3, agf_steps=5, from_init=None,\n", + " utilmax_lr=1, costmin_lr=1, costmin_maxiter=1e4, loss_thresh=1e-4,\n", + "):\n", + " \"\"\"Run the Alternating Gradient Flow (AGF) approximation.\"\"\"\n", + " d_in, d_out = X_train.shape[-1], Y_train.shape[-1]\n", + " if from_init:\n", + " U, V, W = from_init[\"U\"], from_init[\"V\"], from_init[\"W\"]\n", + " width = U.shape[0]\n", + " agf_net = ModsumNetwork(d_in, d_out, init_sz, width=width)\n", + " agf_net.load_init(U, V, W)\n", + " else:\n", + " agf_net = ModsumNetwork(d_in, d_out, init_sz, width=agf_steps)\n", + " X_train.requires_grad = False\n", + "\n", + " results = {\"t\": [], \"losses\": [], \"pred\": []}\n", + "\n", + " def update_results(t):\n", + " results[\"t\"].append(t)\n", + " residual = (Y_train - agf_net(X_train)).detach()\n", + " results[\"losses\"].append((residual**2).mean().item())\n", + " results[\"pred\"].append(agf_net(X_train).detach().cpu().clone())\n", + "\n", + " t = 0\n", + " update_results(t)\n", + " for _ in tqdm(range(agf_steps)):\n", + " residual = (1 / d_out) * 2 * (Y_train - agf_net(X_train))\n", + " residual = residual.detach()\n", + " iters = 0\n", + " mode = \"utilmax\"\n", + " while mode == \"utilmax\":\n", + " for n in agf_net.neurons:\n", + " if n.active:\n", + " continue\n", + " n.utility_step(X_train, residual, utilmax_lr)\n", + " if n.util_acc > n.c_a:\n", + " n.active = True\n", + " mode = \"costmin\"\n", + " iters += 1\n", + " agf_net.set_mode(mode)\n", + " t += iters\n", + "\n", + " agf_opt = torch.optim.SGD(agf_net.parameters(), lr=costmin_lr, momentum=0.9)\n", + " for _ in range(int(costmin_maxiter)):\n", + " agf_opt.zero_grad(set_to_none=False)\n", + " residual = Y_train - agf_net(X_train)\n", + " loss = (residual**2).mean()\n", + " loss.backward()\n", + " agf_opt.step()\n", + " agf_net.set_mode(\"utilmax\")\n", + "\n", + " print(f\"loss: {loss.item():.5f}\")\n", + " update_results(t)\n", + "\n", + " if not agf_net.dormant() or loss.item() < loss_thresh:\n", + " break\n", + "\n", + " return results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if not TEST_MODE:\n", + " agf_results = train_agf(\n", + " X_tensor, Y_tensor,\n", + " init_sz=init_scale,\n", + " agf_steps=50,\n", + " from_init=param_history[0],\n", + " utilmax_lr=0.1,\n", + " costmin_lr=0.01,\n", + " costmin_maxiter=1e4,\n", + " loss_thresh=1e-4,\n", + " )\n", + "else:\n", + " agf_results = None\n", + " print(\"Skipping AGF in TEST_MODE\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loss: GD vs AGF" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "ax.plot(loss_history, lw=4, label=\"GD\")\n", + "\n", + "# Theory plateau levels\n", + "for level in theory[\"levels\"]:\n", + " ax.axhline(y=level, color=\"black\", 
linestyle=\"--\", linewidth=2, zorder=-2)\n", + "\n", + "# AGF overlay\n", + "if agf_results is not None:\n", + " utilmax_lr_val = 0.1\n", + " f = utilmax_lr_val / lr\n", + " agf_times = agf_results[\"t\"] + [epochs]\n", + " agf_losses = agf_results[\"losses\"] + [agf_results[\"losses\"][-1]]\n", + " ax.step(f * np.array(agf_times), agf_losses, where=\"post\", lw=2, ls=\"dashed\", color=\"k\", label=\"AGF\")\n", + " ax.legend(fontsize=14)\n", + "\n", + "ax.set_xscale(\"log\")\n", + "ax.set_yscale(\"log\")\n", + "ax.set_xlabel(\"Epochs\", fontsize=18)\n", + "ax.set_ylabel(\"Train Loss\", fontsize=18)\n", + "ax.set_title(f\"GD vs AGF on $C_{{{p}}}$\", fontsize=20)\n", + "viz.style_axes(ax)\n", + "ax.grid(False)\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(f\"{FIGURES_DIR}/cn_loss_agf.pdf\", bbox_inches=\"tight\")\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "group-agf", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/cnxcn.ipynb b/notebooks/cnxcn.ipynb new file mode 100644 index 0000000..3307c24 --- /dev/null +++ b/notebooks/cnxcn.ipynb @@ -0,0 +1,270 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Binary Group Composition on $C_n \\times C_n$\n", + "\n", + "**Group:** Product of cyclic groups $C_n \\times C_n$ of order $n^2$. \n", + "**Task:** Given encodings of two group elements $g_1, g_2 \\in C_n \\times C_n$, predict the encoding of their product. \n", + "**Sequence length:** $k = 2$ (binary composition). \n", + "**Architecture:** `TwoLayerNet` with square nonlinearity. \n", + "**Key result:** The network learns one irreducible representation at a time." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import random\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.utils.data import DataLoader, TensorDataset\n", + "\n", + "import src.dataset as dataset\n", + "import src.model as model\n", + "import src.power as power\n", + "import src.template as template\n", + "import src.train as train_mod\n", + "import src.viz as viz" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n", + "\n", + "seed = 47\n", + "random.seed(seed)\n", + "np.random.seed(seed)\n", + "torch.manual_seed(seed)\n", + "\n", + "n = 3 if TEST_MODE else 5\n", + "group_size = n * n\n", + "\n", + "hidden_size = 32 if TEST_MODE else 128\n", + "epochs = 2 if TEST_MODE else 1000\n", + "lr = 0.001\n", + "init_scale = 1e-2\n", + "batch_size = 32 if TEST_MODE else 128\n", + "\n", + "FIGURES_DIR = \"figures\"\n", + "os.makedirs(FIGURES_DIR, exist_ok=True)\n", + "\n", + "print(f\"Group: C_{n} x C_{n}, order {group_size}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Template and Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Build a fixed CnxCn template with known Fourier structure\n", + "fourier_coef_mags = np.random.RandomState(seed).rand(n) * 10\n", + "tpl = template.fixed_cnxcn(image_length=n, fourier_coef_mags=fourier_coef_mags)\n", + "\n", + "# Build exhaustive dataset: all group_size^2 pairs\n", + "X, Y = dataset.cnxcn_dataset(tpl)\n", + "X_tensor, Y_tensor, device = dataset.move_dataset_to_device_and_flatten(X, Y)\n", + "\n", + "ds = TensorDataset(X_tensor, Y_tensor)\n", + "dataloader = DataLoader(ds, batch_size=batch_size, shuffle=False)\n", + "\n", + "print(f\"Dataset: {len(ds)} samples\")\n", + "print(f\"X shape: {X_tensor.shape}, Y shape: {Y_tensor.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize template as 2D image\n", + "tpl_2d = tpl.reshape(n, n) if tpl.ndim == 1 else tpl\n", + "\n", + "fig, ax = plt.subplots(figsize=(5, 4))\n", + "im = ax.imshow(tpl_2d, cmap=\"RdBu_r\")\n", + "ax.set_xlabel(\"$C_n$ index\")\n", + "ax.set_ylabel(\"$C_n$ index\")\n", + "ax.set_title(f\"Template on $C_{{{n}}} \\\\times C_{{{n}}}$\")\n", + "plt.colorbar(im, ax=ax)\n", + "plt.tight_layout()\n", + "plt.savefig(f\"{FIGURES_DIR}/cnxcn_template.pdf\", bbox_inches=\"tight\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model and Optimizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "net = model.TwoLayerNet(\n", + " group_size=group_size,\n", + " hidden_size=hidden_size,\n", + " nonlinearity=\"square\",\n", + " init_scale=init_scale,\n", + ")\n", + "net = net.to(device)\n", + "\n", + "criterion = nn.MSELoss()\n", + "opt = optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999))\n", + "\n", + "print(f\"Model: TwoLayerNet(group_size={group_size}, 
hidden={hidden_size})\")\n", + "print(f\"Optimizer: Adam(lr={lr})\")\n", + "print(f\"Training for {epochs} epochs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loss_history, val_loss_history, param_history, param_save_epochs, final_epoch = train_mod.train(\n", + "    net,\n", + "    dataloader,\n", + "    criterion,\n", + "    opt,\n", + "    epochs=epochs,\n", + "    verbose_interval=max(1, epochs // 10),\n", + "    save_param_interval=max(1, epochs // 100),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training Loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Compute theoretical loss plateau levels\n", + "theory = power.theoretical_loss_levels_2d(tpl_2d)\n", + "\n", + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "ax.plot(loss_history, lw=4)\n", + "\n", + "for level in theory[\"levels\"]:\n", + "    ax.axhline(y=level, color=\"black\", linestyle=\"--\", linewidth=2, zorder=-2)\n", + "\n", + "ax.set_xscale(\"log\")\n", + "ax.set_yscale(\"log\")\n", + "ax.set_xlabel(\"Epochs\", fontsize=18)\n", + "ax.set_ylabel(\"Train Loss\", fontsize=18)\n", + "ax.set_title(f\"Training loss on $C_{{{n}}} \\\\times C_{{{n}}}$\", fontsize=20)\n", + "viz.style_axes(ax)\n", + "ax.grid(False)\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(f\"{FIGURES_DIR}/cnxcn_loss.pdf\", bbox_inches=\"tight\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Show a few predictions vs ground truth\n", + "net.load_state_dict(param_history[-1])\n", + "net.eval()\n", + "\n", + "n_examples = 3\n", + "indices = np.random.choice(len(Y_tensor), size=n_examples, replace=False)\n", + "\n", + "fig, axes = plt.subplots(n_examples, 2, figsize=(8, 3 * n_examples))\n", + "\n", + "with torch.no_grad():\n", + "    preds = net(X_tensor[indices]).detach().cpu().numpy()\n", + "    truths = Y_tensor[indices].detach().cpu().numpy()\n", + "\n", + "for i in range(n_examples):\n", + "    axes[i, 0].imshow(truths[i].reshape(n, n), cmap=\"RdBu_r\")\n", + "    axes[i, 0].set_title(\"Ground truth\")\n", + "    axes[i, 1].imshow(preds[i].reshape(n, n), cmap=\"RdBu_r\")\n", + "    axes[i, 1].set_title(\"Prediction\")\n", + "\n", + "plt.suptitle(f\"Predictions on $C_{{{n}}} \\\\times C_{{{n}}}$ (epoch {final_epoch})\", fontsize=16)\n", + "plt.tight_layout()\n", + "plt.savefig(f\"{FIGURES_DIR}/cnxcn_predictions.pdf\", bbox_inches=\"tight\")\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "group-agf", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/dihedral.ipynb b/notebooks/dihedral.ipynb deleted file mode 100644 index a7bb7ca..0000000 --- a/notebooks/dihedral.ipynb +++ /dev/null @@ -1,1064 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "51d11caf-0971-4324-b63b-819b714a9c3c", - "metadata": {}, - "source": [ - "# Dihedral 1D" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "80f249f1-6985-4c73-86cd-04e1adac3e8d", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import random\n", - "import torch\n", - "import 
torch.nn as nn\n", - "from tqdm import tqdm\n", - "import torch.optim as optim\n", - "from torch.utils.data import DataLoader, TensorDataset\n", - "import seaborn as sns\n", - "import matplotlib.pyplot as plt\n", - "import matplotlib.cm as cm\n", - "from matplotlib.animation import FuncAnimation\n", - "from matplotlib.ticker import FormatStrFormatter\n", - "from matplotlib.ticker import FuncFormatter\n", - "from matplotlib.ticker import MaxNLocator" - ] - }, - { - "cell_type": "markdown", - "id": "9fd05577-db56-4d0a-bb93-1d0b48cecaf6", - "metadata": {}, - "source": [ - "## Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f19bd1ad-9e8f-4720-b317-afe13fafae88", - "metadata": {}, - "outputs": [], - "source": [ - "def one_hot(p):\n", - "    \"\"\"Return the one-hot encoding of 0 in R^p, used as a base template.\"\"\"\n", - "    vec = np.zeros(p)\n", - "    vec[0] = 1\n", - "    return vec\n", - "\n", - "def generate_template(p, magnitude, exponent):\n", - "    weight = magnitude * np.power(np.arange(1, p), -exponent) # Power-law singular values\n", - "    template = np.ones(p) # Base term (DC component)\n", - "    for freq in range(1, p):\n", - "        template += weight[freq-1] * np.cos(np.arange(p) * freq / p * 2 * np.pi)\n", - "    return template / p\n", - "\n", - "def generate_fixed_template(p):\n", - "    # Generate template array from Fourier spectrum\n", - "    spectrum = np.zeros(p, dtype=complex)\n", - "    \n", - "    # Set only three frequencies with specific amplitudes\n", - "    spectrum[1] = 10 # First frequency\n", - "    spectrum[-1] = 10 # Its conjugate\n", - "    spectrum[3] = 5 # Second frequency\n", - "    spectrum[-3] = 5 # Its conjugate\n", - "    spectrum[5] = 2.5 # Third frequency\n", - "    spectrum[-5] = 2.5 # Its conjugate\n", - "    \n", - "    # Generate signal from spectrum\n", - "    template = np.fft.ifft(spectrum).real\n", - "\n", - "    return template\n", - "\n", - "def DihedralDataset(p, template):\n", - "    # Initialize data arrays\n", - "    X = np.zeros((4*p * p, 2, p))  # Shape: (4*p^2, 2, p)\n", - "    Y = np.zeros((4*p * p, p))  # Shape: (4*p^2, p)\n", - "    \n", - "    # Generate the dataset\n", - "    idx = 0\n", - "    for a in range(p):\n", - "        for b in range(p):\n", - "            for sa in (-1, 1):\n", - "                for sb in (-1, 1):\n", - "                    X[idx, 0, :] = np.roll(template, a)[::sa]  # reversed when sa == -1, unchanged when sa == 1\n", - "                    X[idx, 1, :] = np.roll(template, b)[::sb]\n", - "                    Y[idx, :] = np.roll(np.roll(template, a)[::sa], b)[::sb]\n", - "                    idx += 1\n", - "    \n", - "    return X, Y" - ] - }, - { - "cell_type": "markdown", - "id": "7a0ecbbd-ceaf-4bef-af4a-13a22fa70063", - "metadata": {}, - "source": [ - "## Architecture" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2cf22b7d-49e7-445b-8742-2e75cd1fa55a", - "metadata": {}, - "outputs": [], - "source": [ - "class TwoLayerNet(nn.Module):\n", - "    def __init__(self, p, hidden_size, nonlinearity='square', init_scale=1.0, output_scale=1.0):\n", - "        super(TwoLayerNet, self).__init__()\n", - "        \n", - "        # Store dimensions\n", - "        self.p = p\n", - "        self.hidden_size = hidden_size\n", - "        self.nonlinearity = nonlinearity\n", - "        self.init_scale = init_scale\n", - "        self.output_scale = output_scale\n", - "        \n", - "        # Initialize parameters \n", - "        self.U = nn.Parameter(self.init_scale * torch.randn(hidden_size, p) / np.sqrt(2 * p))  # First p elements\n", - "        self.V = nn.Parameter(self.init_scale * torch.randn(hidden_size, p) / np.sqrt(2 * p))  # Second p elements\n", - "        self.W = nn.Parameter(self.init_scale * torch.randn(hidden_size, p) / np.sqrt(p))  # 
Second layer weights\n", - " print(f\"Initialized U with shape {self.U.shape}\")\n", - " print(f\"Initialized V with shape {self.V.shape}\")\n", - " print(f\"Initialized W with shape {self.W.shape}\")\n", - "\n", - " def forward(self, x):\n", - " print(f\"Input x shape: {x.shape}\")\n", - " # First layer (linear and combined)\n", - " x1 = x[:, :self.p] @ self.U.T\n", - " print(f\"x1 (x @ U.T) shape: {x1.shape}\")\n", - " x2 = x[:, self.p:] @ self.V.T\n", - " print(f\"x2 (x @ V.T) shape: {x2.shape}\")\n", - " x_combined = x1 + x2\n", - " print(f\"x_combined (x1 + x2) shape: {x_combined.shape}\")\n", - "\n", - " # Apply nonlinearity activation\n", - " if self.nonlinearity == 'relu':\n", - " x_combined = torch.relu(x_combined)\n", - " print(\"Applied ReLU nonlinearity\")\n", - " elif self.nonlinearity == 'square':\n", - " x_combined = x_combined**2\n", - " print(\"Applied square nonlinearity\")\n", - " elif self.nonlinearity == 'linear':\n", - " x_combined = x_combined\n", - " print(\"Applied linear (identity) nonlinearity\")\n", - " elif self.nonlinearity == 'tanh':\n", - " x_combined = torch.tanh(x_combined)\n", - " print(\"Applied tanh nonlinearity\")\n", - " elif self.nonlinearity == 'gelu':\n", - " gelu = torch.nn.GELU()\n", - " x_combined = gelu(x_combined)\n", - " print(\"Applied GELU nonlinearity\")\n", - " else:\n", - " raise ValueError(f\"Invalid nonlinearity '{self.nonlinearity}' provided.\")\n", - "\n", - " # Second layer (linear)\n", - " x_out = x_combined @ self.W\n", - " print(f\"x_out (x_combined @ W) shape: {x_out.shape}\")\n", - "\n", - " # Feature learning scaling\n", - " x_out *= self.output_scale\n", - " print(f\"x_out after scaling with output_scale={self.output_scale}: shape {x_out.shape}\")\n", - " \n", - " return x_out" - ] - }, - { - "cell_type": "markdown", - "id": "f7e7336b-5c6e-48af-a357-2b2c877f6168", - "metadata": {}, - "source": [ - "## Optimization" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1035f81c-e877-4655-8640-4e4c3d323af8", - "metadata": {}, - "outputs": [], - "source": [ - "def test_accuracy(model, dataloader):\n", - " correct = 0\n", - " total = 0\n", - " print(\"Starting test_accuracy evaluation...\")\n", - " \n", - " with torch.no_grad(): # Disable gradient calculation for evaluation\n", - " for i, (inputs, labels) in enumerate(dataloader):\n", - " inputs = inputs.view(inputs.shape[0], -1) # Flatten input for FC layers\n", - " print(f\"Batch {i+1}: inputs reshaped\")\n", - " outputs = model(inputs)\n", - " print(f\"Batch {i+1}: model forward pass done\")\n", - " _, predicted = torch.max(outputs, 1) # Get the index of the largest value (class)\n", - " _, true_labels = torch.max(labels, 1) # Get the true class from the one-hot encoding\n", - " correct += (predicted == true_labels).sum().item()\n", - " total += labels.size(0)\n", - " print(f\"Batch {i+1}: accuracy updated (correct={correct}, total={total})\")\n", - " \n", - " accuracy = 100 * correct / total\n", - " print(f\"Final test accuracy: {accuracy:.2f}%\")\n", - " return accuracy\n", - "\n", - "def train(model, dataloader, criterion, optimizer, epochs=100, verbose_interval=10):\n", - " print(\"Starting training loop...\")\n", - " model.train() # Set the model to training mode\n", - " print(\"Model set to train mode.\")\n", - " loss_history = [] # List to store loss values\n", - " accuracy_history = []\n", - " param_history = []\n", - "\n", - " for epoch in range(epochs):\n", - " print(f\"Epoch {epoch+1} started.\")\n", - " running_loss = 0.0\n", - " for batch_idx, 
(inputs, labels) in enumerate(dataloader):\n", - " inputs = inputs.view(inputs.shape[0], -1) # Flatten input for FC layers\n", - " print(f\" Batch {batch_idx+1}: inputs reshaped\")\n", - "\n", - " optimizer.zero_grad() # Zero gradients\n", - " print(f\" Batch {batch_idx+1}: optimizer gradients zeroed\")\n", - " outputs = model(inputs) # Forward pass\n", - " print(f\" Batch {batch_idx+1}: model forward pass done\")\n", - " loss = criterion(outputs, labels) # Compute loss\n", - " print(f\" Batch {batch_idx+1}: loss computed ({loss.item():.4f})\")\n", - " loss.backward() # Backpropagation\n", - " print(f\" Batch {batch_idx+1}: backward pass done\")\n", - " optimizer.step() # Update weights\n", - " print(f\" Batch {batch_idx+1}: optimizer step done\")\n", - "\n", - " running_loss += loss.item()\n", - " print(f\" Batch {batch_idx+1}: running_loss updated ({running_loss:.4f})\")\n", - "\n", - " # Append the average loss for the epoch to loss_history\n", - " avg_loss = running_loss / len(dataloader)\n", - " loss_history.append(avg_loss)\n", - " print(f\"Epoch {epoch+1}: avg_loss appended ({avg_loss:.4f})\")\n", - "\n", - " # Append the accuracy\n", - " model.eval()\n", - " print(f\"Epoch {epoch+1}: model set to eval mode for accuracy check\")\n", - " accuracy = test_accuracy(model, dataloader)\n", - " accuracy_history.append(accuracy)\n", - " print(f\"Epoch {epoch+1}: accuracy appended ({accuracy:.2f}%)\")\n", - " model.train()\n", - " print(f\"Epoch {epoch+1}: model set back to train mode\")\n", - "\n", - " # Save current model parameters\n", - " current_params = {\n", - " \"U\": model.U.detach().cpu().clone(),\n", - " \"V\": model.V.detach().cpu().clone(),\n", - " \"W\": model.W.detach().cpu().clone()\n", - " }\n", - " param_history.append(current_params)\n", - " print(f\"Epoch {epoch+1}: model parameters saved\")\n", - "\n", - " # Print verbose information every `verbose_interval` epochs\n", - " if (epoch + 1) % verbose_interval == 0:\n", - " print(f\"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%\")\n", - "\n", - " print(\"Training loop finished.\")\n", - " return loss_history, accuracy_history, param_history # Return loss history for plotting" - ] - }, - { - "cell_type": "markdown", - "id": "0e86c4f6-83a6-4465-abf0-7d104432cc9c", - "metadata": {}, - "source": [ - "## Plotting functions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "014e2d10-9550-4fd4-adb7-168a27fda1b3", - "metadata": {}, - "outputs": [], - "source": [ - "def style_axes(ax, numyticks=5, numxticks=5, labelsize=24):\n", - " # Y-axis ticks\n", - " ax.tick_params(axis=\"y\", which=\"both\", bottom=True, top=False,\n", - " labelbottom=True, left=True, right=False,\n", - " labelleft=True, direction='out', length=7, width=1.5, pad=8, labelsize=labelsize)\n", - " ax.yaxis.set_major_locator(MaxNLocator(nbins=numyticks))\n", - " \n", - " # X-axis ticks\n", - " ax.tick_params(axis=\"x\", which=\"both\", bottom=True, top=False,\n", - " labelbottom=True, left=True, right=False,\n", - " labelleft=True, direction='out', length=7, width=1.5, pad=8, labelsize=labelsize)\n", - " ax.xaxis.set_major_locator(MaxNLocator(nbins=numxticks))\n", - "\n", - " # Scientific notation formatting\n", - " if ax.get_yscale() == 'linear':\n", - " ax.ticklabel_format(style='sci', axis='y', scilimits=(-2, 2))\n", - " if ax.get_xscale() == 'linear':\n", - " ax.ticklabel_format(style='sci', axis='x', scilimits=(-2, 2))\n", - "\n", - " ax.xaxis.offsetText.set_fontsize(20)\n", - " ax.grid()\n", - "\n", - " # 
Customize spines\n", - "    for spine in [\"top\", \"right\"]:\n", - "        ax.spines[spine].set_visible(False)\n", - "    for spine in [\"left\", \"bottom\"]:\n", - "        ax.spines[spine].set_linewidth(3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "20989d96-f34f-4be7-a0f9-4b92fb7f235a", - "metadata": {}, - "outputs": [], - "source": [ - "def get_power(points):\n", - "    p = len(points)\n", - "    num_coefficients = (p // 2) + 1\n", - "    \n", - "    # Perform FFT and calculate power spectrum\n", - "    ft = np.fft.fft(points)  # Could consider using np.fft.rfft which is designed for real valued input.\n", - "    power = np.abs(ft[:num_coefficients])**2 / p\n", - "    \n", - "    # Double power for frequencies strictly between 0 and Nyquist (Nyquist is not doubled if p is even)\n", - "    if p % 2 == 0:  # p is even, Nyquist frequency at index num_coefficients - 1\n", - "        power[1:num_coefficients - 1] *= 2\n", - "    else:  # p is odd, no Nyquist frequency\n", - "        power[1:] *= 2\n", - "\n", - "    # Confirm the power sum approximates the squared norm of points\n", - "    total_power = np.sum(power)\n", - "    norm_squared = np.linalg.norm(points)**2\n", - "    if not np.isclose(total_power, norm_squared, rtol=1e-3):\n", - "        print(f\"Warning: Total power {total_power:.3f} does not match norm squared {norm_squared:.3f}\")\n", - "\n", - "    return np.arange(num_coefficients), power\n", - "\n", - "def interpolate(ax, points, color, continuous, alpha=1.0):\n", - "    p = len(points)\n", - "    if continuous:\n", - "        # Perform Fourier Transform\n", - "        ft = np.fft.fft(points)\n", - "        \n", - "        # Keep only non-negative frequencies (first half + Nyquist if p is even)\n", - "        num_coefficients = (p // 2) + 1\n", - "        ft = ft[:num_coefficients]  # Truncate to keep non-negative frequencies\n", - "        \n", - "        # Create a dense set of x-values for smooth interpolation\n", - "        xs = np.linspace(0, p, 10 * p)  # 10 times more points than the original for smoothness\n", - "        curr_val = np.zeros(xs.shape, dtype=complex)\n", - "        \n", - "        # Use only non-negative frequencies for interpolation\n", - "        for freq in range(num_coefficients):\n", - "            theta = np.angle(ft[freq])\n", - "            r = np.abs(ft[freq]) / p\n", - "            # Double amplitude except for DC (freq = 0) and Nyquist (freq = p / 2, when p is even)\n", - "            if freq > 0 and (freq < p / 2 or p % 2 != 0):\n", - "                r *= 2\n", - "            curr_val += r * np.exp(1j * ((2 * np.pi * freq * xs / p) + theta))\n", - "\n", - "        # Plot the real part (since output is real-valued)\n", - "        ax.plot(xs, curr_val.real, color=color, alpha=alpha)\n", - "    else:\n", - "        ax.plot(np.arange(p), points, color=color, alpha=alpha) " - ] - }, - { - "cell_type": "markdown", - "id": "e99dae27-f8fe-403a-b70f-0bcaf818cbe7", - "metadata": {}, - "source": [ - "## Gradient Descent Experiment" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bcd15c5a-5745-41ba-b015-48e403160c7e", - "metadata": {}, - "outputs": [], - "source": [ - "seed = 0  # or any integer you like\n", - "random.seed(seed)\n", - "np.random.seed(seed)\n", - "torch.manual_seed(seed)\n", - "torch.cuda.manual_seed_all(seed)  # if using GPU\n", - "\n", - "# TEST_MODE: Reduce hidden_size and epochs for faster automated testing\n", - "import os\n", - "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n", - "\n", - "# Data Generation\n", - "p = 10  # Modulus (kept the same in TEST_MODE to avoid index errors)\n", - "\n", - "# Get base vector\n", - "# template = generate_template(p, 2, 1.0)\n", - "# template = 
one_hot(p)\n", - "template = generate_fixed_template(p)\n", - "\n", - "# Mean center template\n", - "template -= np.mean(template)\n", - "\n", - "# Generate dataset using numpy\n", - "X, Y = DihedralDataset(p, template)\n", - "\n", - "# Convert to PyTorch tensors\n", - "X_tensor = torch.tensor(X, dtype=torch.float32).view(-1, 2 * p) # Flatten input (num_samples, 2*p)\n", - "Y_tensor = torch.tensor(Y, dtype=torch.float32) # Targets (num_samples, p)\n", - "\n", - "# Create a TensorDataset and DataLoader\n", - "dataset = TensorDataset(X_tensor, Y_tensor)\n", - "dataloader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)\n", - "# dataloader = DataLoader(dataset, batch_size=32, shuffle=False)\n", - "\n", - "# Initialize model\n", - "hidden_size = 6 if TEST_MODE else 6 * 3 # Reduced in test mode\n", - "model = TwoLayerNet(p=p, hidden_size=hidden_size, nonlinearity='square', init_scale=1e-2, output_scale=1e0)\n", - "\n", - "# Create loss function\n", - "loss = nn.MSELoss()\n", - "\n", - "# Construct optimizer\n", - "lr, mom = 0.01, 0.9\n", - "optimizer = optim.SGD(model.parameters(), lr=lr, momentum=mom)\n", - "# optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))\n", - "\n", - "# Train the model\n", - "epochs = 2 if TEST_MODE else 10000\n", - "loss_history, accuracy_history, param_history = train(model, dataloader, loss, optimizer, epochs=epochs, verbose_interval=max(1, epochs//100))" - ] - }, - { - "cell_type": "markdown", - "id": "eae371c4-1405-4ac5-982c-0ebacb688ed7", - "metadata": {}, - "source": [ - "## AGF Numerics" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "489e82e1-61c8-43e6-b260-fd96c815dec8", - "metadata": {}, - "outputs": [], - "source": [ - "class ModsumSubNetwork(nn.Module):\n", - " \n", - " def __init__(self, d_in, d_out, init_scale):\n", - " super().__init__()\n", - " assert d_in%2 == 0\n", - " self.p = d_in // 2\n", - " self.u = nn.Linear(self.p, 1, bias=False)\n", - " self.v = nn.Linear(self.p, 1, bias=False)\n", - " self.w = nn.Linear(1, d_out, bias=False)\n", - " with torch.no_grad():\n", - " self.w.weight.mul_(init_scale)\n", - " self.u.weight.mul_(init_scale)\n", - " self.v.weight.mul_(init_scale)\n", - " self.active = False\n", - " self.util_acc = 0\n", - " self.c_a = 1/self.get_norm() - 1\n", - " \n", - " self.normalize()\n", - " \n", - " def get_norm(self):\n", - " sqnorm = lambda x: torch.linalg.norm(x.weight)**2\n", - " norm = torch.sqrt(sqnorm(self.w) + sqnorm(self.u) + sqnorm(self.v))\n", - " return norm\n", - " \n", - " def reinitialize(self, u, v, w):\n", - " with torch.no_grad():\n", - " self.u.weight.copy_(u)\n", - " self.v.weight.copy_(v)\n", - " self.w.weight.copy_(w)\n", - " self.c_a = 1/self.get_norm() - 1\n", - " \n", - " def forward(self, x):\n", - " x1 = x[:, :self.p]\n", - " x2 = x[:, self.p:]\n", - " return self.w((self.u(x1) + self.v(x2))**2)\n", - " \n", - " def normalize(self):\n", - " norm = self.get_norm()\n", - " with torch.no_grad():\n", - " self.w.weight.div_(norm)\n", - " self.u.weight.div_(norm)\n", - " self.v.weight.div_(norm)\n", - " \n", - " def utility_step(self, x, residual, learning_rate):\n", - " f_i = self(x)\n", - " util = torch.einsum('nd,nd->n', f_i, residual).mean()\n", - " self.util_acc += 3 * learning_rate * util.item()\n", - " norm_th = 1/(1 + self.c_a - self.util_acc)\n", - " \n", - " util.backward()\n", - " with torch.no_grad():\n", - " self.w.weight += norm_th * learning_rate * self.w.weight.grad\n", - " self.u.weight += norm_th * learning_rate * 
self.u.weight.grad\n", - " self.v.weight += norm_th * learning_rate * self.v.weight.grad\n", - " self.w.weight.grad.zero_()\n", - " self.u.weight.grad.zero_()\n", - " self.v.weight.grad.zero_()\n", - " self.normalize()\n", - "\n", - "\n", - "class ModsumNetwork(nn.Module):\n", - " \n", - " def __init__(self, d_in, d_out, init_scale, width=100):\n", - " super().__init__()\n", - " self.d_in = d_in\n", - " self.d_out = d_out\n", - " self.width = width\n", - " neurons = [ModsumSubNetwork(d_in, d_out, init_scale) for _ in range(width)]\n", - " self.neurons = nn.ModuleList(neurons)\n", - " self.set_mode(\"utilmax\")\n", - " \n", - " def load_init(self, U, V, W):\n", - " for i, n in enumerate(self.neurons):\n", - " u, v, w = U[i], V[i], W[i][:, None]\n", - " n.reinitialize(u, v, w)\n", - "\n", - " def dormant(self):\n", - " return [neuron for neuron in self.neurons if not neuron.active]\n", - " \n", - " def active(self):\n", - " return [neuron for neuron in self.neurons if neuron.active]\n", - "\n", - " \n", - " def set_mode(self, mode):\n", - " if mode not in [\"utilmax\", \"costmin\"]:\n", - " raise ValueError(\"mode must be utilmax or costmin\")\n", - " self.mode = mode\n", - " for neuron in self.neurons:\n", - " grad_on = (mode==\"utilmax\") ^ neuron.active\n", - " for param in neuron.parameters():\n", - " param.requires_grad = grad_on\n", - " \n", - " def forward(self, x):\n", - " if not np.any([n.active for n in self.neurons]):\n", - " return torch.zeros(x.shape[0], self.d_out)\n", - " else:\n", - " outputs = torch.stack([neuron(x) for neuron in self.neurons if neuron.active], dim=0)\n", - " return torch.sum(outputs, dim=0)\n", - "\n", - "\n", - "def train_agf(X_train, Y_train, init_sz=1e-3, agf_steps=5, from_init=None, \n", - " utilmax_lr=1, costmin_lr=1, costmin_maxiter=1e4, loss_thresh=1e-4):\n", - " \n", - " # Initialize\n", - " d_in, d_out = X_train.shape[-1], Y_train.shape[-1]\n", - " if from_init:\n", - " U, V, W = from_init[\"U\"], from_init[\"V\"], from_init[\"W\"]\n", - " assert d_in == U.shape[1]*2\n", - " assert d_out == W.shape[1]\n", - " width = U.shape[0]\n", - " net = ModsumNetwork(d_in, d_out, init_sz, width=width)#.cuda()\n", - " net.load_init(U, V, W)\n", - " else:\n", - " net = ModsumNetwork(d_in, d_out, init_sz, width=agf_steps)#.cuda()\n", - " X_train.requires_grad = False\n", - " \n", - " def update_results(results, t):\n", - " results[\"t\"].append(t)\n", - " residual = (Y_train - net(X_train))\n", - " residual = residual.detach()\n", - " results[\"residuals\"].append(residual)\n", - " loss = (residual**2).mean().item()\n", - " results[\"losses\"].append(loss)\n", - " results[\"models\"].append(net.state_dict())\n", - " results[\"pred\"].append(net(X_train).detach().cpu().clone())\n", - " \n", - " results = {\n", - " \"t\": [],\n", - " \"residuals\": [],\n", - " \"losses\": [],\n", - " \"models\": [],\n", - " \"pred\": [],\n", - " }\n", - "\n", - " t = 0\n", - " update_results(results, t)\n", - " for _ in tqdm(range(agf_steps)):\n", - " \n", - " # Utility Maximization\n", - " residual = (1/d_out) * 2*(Y_train - net(X_train))\n", - " residual = residual.detach()\n", - " iters = 0\n", - " mode = \"utilmax\"\n", - " while mode == \"utilmax\":\n", - " for n in net.neurons:\n", - " if n.active:\n", - " continue\n", - " n.utility_step(X_train, residual, utilmax_lr)\n", - " if n.util_acc > n.c_a:\n", - " n.active = True\n", - " mode = \"costmin\"\n", - " # break\n", - " iters += 1\n", - " net.set_mode(mode)\n", - " t += iters\n", - "\n", - " # Cost Minimization\n", - " 
optimizer = torch.optim.SGD(net.parameters(), lr=costmin_lr, momentum=0.9)\n", - " for i in range(int(costmin_maxiter)):\n", - " optimizer.zero_grad(set_to_none=False)\n", - " residual = Y_train - net(X_train)\n", - " loss = (residual ** 2).mean()\n", - " loss.backward()\n", - " optimizer.step()\n", - " net.set_mode(\"utilmax\")\n", - "\n", - " \n", - " print(f\"loss: {loss.item():.5f}\")\n", - " update_results(results, t)\n", - "\n", - " # Check for Termination\n", - " if not net.dormant() or loss.item() < loss_thresh:\n", - " break\n", - " \n", - " return results" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5b3eddcc-8c3e-45dd-8f57-45585a021f5d", - "metadata": {}, - "outputs": [], - "source": [ - "costmin_lr = 0.01\n", - "utilmax_lr = 0.1\n", - "results = train_agf(X_tensor, Y_tensor, init_sz=model.init_scale, agf_steps=50, from_init=param_history[0],\n", - " utilmax_lr=utilmax_lr, costmin_lr=costmin_lr,\n", - " costmin_maxiter=1e4, loss_thresh=1e-4)" - ] - }, - { - "cell_type": "markdown", - "id": "0f48aebc-a439-405a-a057-3f5c24cca91a", - "metadata": {}, - "source": [ - "## Plot Loss" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ff46febe-abb5-459a-bb06-a18a26afb967", - "metadata": {}, - "outputs": [], - "source": [ - "fig, ax = plt.subplots(1, 1, figsize=(6, 6))\n", - "ax.plot(list(loss_history), lw=4)\n", - "\n", - "# for lossval in results[\"losses\"]:\n", - "# ax.axhline(lossval, alpha=0.3, ls=\":\", color=\"xkcd:slate\", zorder=-4, lw=2)\n", - "\n", - "f = utilmax_lr / (lr/(1-mom))\n", - "# for t in results[\"t\"]:\n", - "# ax.axvline(f*t, alpha=0.3, ls=\":\", color=\"xkcd:slate\", zorder=-4, lw=2)\n", - "\n", - "# times = results[\"t\"] + [epochs]\n", - "# AGF_losses = results[\"losses\"] + [results[\"losses\"][-1]]\n", - "# ax.step(f*np.array(times), AGF_losses, where=\"post\", lw=2, ls='dashed', color=\"k\")\n", - "\n", - "# === Compute power spectrum of template ===\n", - "freq, power = get_power(template)\n", - "valid = power > 1e-20\n", - "freq, power = freq[valid], power[valid]\n", - "sorted_idx = np.argsort(-power)\n", - "freq, power = freq[sorted_idx], power[sorted_idx]\n", - "\n", - "alpha_values = [np.sum(power[k:]) for k in range(len(power))]\n", - "coef = 1 / p\n", - "for k, alpha in enumerate(alpha_values):\n", - " ax.axhline(y=coef * alpha, color='black', linestyle='--', linewidth=2, zorder=-2)\n", - "\n", - "ax.set_xscale(\"log\")\n", - "ax.set_yscale(\"log\")\n", - "ax.set_xlim(1e1, 1e4)\n", - "ax.set_ylim(1e-1, 1e1)\n", - "ax.set_xlabel('Epochs', fontsize=24)\n", - "ax.set_ylabel('Train Loss', fontsize=24)\n", - "\n", - "style_axes(ax)\n", - "plt.grid(False)\n", - "plt.tight_layout()\n", - "plt.savefig(\"loss-without-lines.pdf\", bbox_inches=\"tight\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "40b851e7-6256-43cd-b9f3-aca38db04917", - "metadata": {}, - "source": [ - "## Power Spectrum of output" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "68b25ca9-6339-49dd-9d45-577a51798a25", - "metadata": {}, - "outputs": [], - "source": [ - "# === SETTINGS ===\n", - "p = Y_tensor.shape[1]\n", - "num_freqs = p // 2 + 1\n", - "\n", - "# Compute template power spectrum\n", - "template_ft = np.fft.rfft(template)\n", - "template_power = np.abs(template_ft)[:num_freqs]\n", - "\n", - "# === Compute power spectrum of template ===\n", - "freq, power = get_power(template)\n", - "valid = power > 1e-20\n", - "freq, power = freq[valid], power[valid]\n", - "sorted_idx = 
np.argsort(-power)\n", - "freq, power = freq[sorted_idx], power[sorted_idx]\n", - "\n", - "# === Theory lines ===\n", - "alpha_values = [np.sum(power[k:]) for k in range(len(power))]\n", - "coef = 1 / p\n", - "theta0 = np.sqrt(2) * model.init_scale\n", - "uMax = [np.sqrt(2 * p / 27) * (p * power[k] / 2)**(3/2) / p**2 for k in range(len(power))]\n", - "tau_values = [(1 / theta0 - 1) / (3 * uMax[k]) for k in range(len(uMax))]\n", - "step_size = 2 * coef * lr / (1 - mom)\n", - "\n", - "\n", - "# Color settings\n", - "cmap = plt.colormaps.get_cmap('tab20').resampled(num_freqs)\n", - "manual_colors = {\n", - " 0: 'tab:blue',\n", - " 1: 'tab:orange',\n", - " 2: 'tab:red',\n", - " 3: 'tab:green',\n", - " 4: 'tab:brown',\n", - " 5: 'tab:purple',\n", - "}\n", - "colors = [manual_colors.get(i, cmap(i)) for i in range(num_freqs)]\n", - "\n", - "# Compute output power over time (GD)\n", - "num_points = 1000\n", - "steps = np.unique(np.logspace(0, np.log10(len(param_history) - 1), num_points, dtype=int))\n", - "powers_over_time = []\n", - "\n", - "for step in steps:\n", - " model.load_state_dict(param_history[step])\n", - " model.eval()\n", - " with torch.no_grad():\n", - " outputs = model(X_tensor)\n", - " ft = np.fft.rfft(outputs.detach().cpu().numpy(), axis=1)\n", - " avg_power = np.mean(np.abs(ft), axis=0)\n", - " powers_over_time.append(avg_power)\n", - "\n", - "powers_over_time = np.array(powers_over_time) # shape: (steps, freqs)\n", - "\n", - "\n", - "# # Compute output power over time (AGF)\n", - "# f = utilmax_lr / (lr/(1-mom))\n", - "# AGF_steps = results[\"t\"]\n", - "# powers_over_time_AGF = []\n", - "# for i, step in enumerate(AGF_steps):\n", - "# outputs = results[\"pred\"][i]\n", - "# ft = np.fft.rfft(outputs.detach().cpu().numpy(), axis=1)\n", - "# avg_power = np.mean(np.abs(ft), axis=0)\n", - "# powers_over_time_AGF.append(avg_power)\n", - "# powers_over_time_AGF = np.array(powers_over_time_AGF) # shape: (steps, freqs)\n", - "# AGF_steps = [f * t for t in AGF_steps]\n", - "\n", - "# AGF_steps.append(epochs)\n", - "# powers_over_time_AGF = np.vstack([\n", - "# powers_over_time_AGF,\n", - "# powers_over_time_AGF[-1, :]\n", - "# ])\n", - "\n", - "# === PLOTTING ===\n", - "fig, ax = plt.subplots(figsize=(6, 7))\n", - "\n", - "for k in range(num_freqs):\n", - " color = colors[k]\n", - " label = fr\"$\\xi = {k}$\" if k in [1, 3, 5] else None\n", - " ax.plot(steps, powers_over_time[:, k], color=color, lw=3, label=label)\n", - " label_agf = 'AGF' if k == 10 else None\n", - " #ax.step(AGF_steps, powers_over_time_AGF[:, k], color='k', lw=2, ls='dashed', where=\"post\", label=label_agf)\n", - " ax.axhline(template_power[k], color=color, linestyle='dotted', linewidth=2, alpha=0.5, zorder=-10)\n", - "\n", - "for k, tau in enumerate(tau_values):\n", - " color = colors[freq[k]]\n", - " ax.axvline(x=tau / step_size, color=color, linestyle='dashed', linewidth=2, alpha=0.5)\n", - "\n", - " # Add arrow at intersection\n", - " x = tau / step_size\n", - " y = template_power[freq[k]]\n", - " #draw an arrow from the lower bound to the right\n", - " #use default color cycle\n", - " # ax.arrow(1.04 * x, y + 0.5, 1.5 * x, 0, \n", - " # head_width=0.2, head_length=x*0.2, length_includes_head=True,\n", - " # fc=color, ec=color, lw=4)\n", - "\n", - "# # Add vertical lines if needed\n", - "# for step in time_steps:\n", - "# ax.axvline(x=step, color='gray', alpha=0.5, linestyle='solid', linewidth=2)\n", - "\n", - "# Labeling and formatting\n", - "ax.set_xscale('log')\n", - "ax.set_xlim(5e1, 2e6)\n", - 
"ax.set_xticks([1000, 10000, 100000, epochs-1])\n", - "ax.set_ylabel(\"Power\", fontsize=24)\n", - "ax.set_xlabel(\"Epochs\", fontsize=24)\n", - "ax.legend(fontsize=14, title=\"Frequency\", title_fontsize=16, loc='upper right', bbox_to_anchor=(1, 0.9), labelspacing=0.25)\n", - "\n", - "style_axes(ax)\n", - "ax.set_xticks([1000, 10000, 100000, epochs-1])\n", - "ax.grid(False)\n", - "plt.tight_layout()\n", - "plt.savefig(\"fourier_power_only.pdf\", bbox_inches=\"tight\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "5ef2c971-d9f1-41e6-b8eb-4e467496ccfd", - "metadata": {}, - "source": [ - "## Plot outputs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e333d1ab-1501-434f-86d2-82c10bb58f11", - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "# Choose time steps to visualize\n", - "# steps_to_show = [1000, 10000, 100000, epochs-1]\n", - "steps_to_show = [100, 1000, epochs-1]\n", - "num_samples = 1 # how many examples to plot per row\n", - "p = Y_tensor.shape[1]\n", - "x = np.arange(p)\n", - "\n", - "fig, axes = plt.subplots(len(steps_to_show), 1, figsize=(6, 6), sharex=True)\n", - "\n", - "for row, step in enumerate(steps_to_show):\n", - " # Load weights at this step\n", - " model.load_state_dict(param_history[step])\n", - " model.eval()\n", - "\n", - " indices = np.random.choice(len(Y_tensor), size=num_samples, replace=False)\n", - " with torch.no_grad():\n", - " preds = model(X_tensor[indices]).detach().cpu().numpy()\n", - " truths = Y_tensor[indices].detach().cpu().numpy()\n", - "\n", - " ax = axes[row]\n", - " for i, idx in enumerate(indices):\n", - " a = idx // p\n", - " b = idx % p\n", - " label_true = r\"$(a + b) \\cdot x$\"\n", - " label_pred = r\"$f(a \\cdot x, b \\cdot x)$\"\n", - "\n", - " # Plot ground truth\n", - " interpolate(ax, truths[i], color=f\"C{i}\", alpha=0.9, continuous=True)\n", - " ax.scatter(x, truths[i], color=f\"C{i}\", s=30, alpha=0.9, label=label_true)\n", - "\n", - " # Plot prediction\n", - " interpolate(ax, preds[i], color='k', alpha=1.0, continuous=True)\n", - " ax.scatter(x, preds[i], color='k', s=30, alpha=0.7, label=label_pred)\n", - "\n", - " style_axes(ax, numyticks=3, labelsize=12)\n", - " ax.grid(False)\n", - " ax.set_ylabel(fr\"$t = 10^{{{int(np.log10(step))}}}$\", fontsize=20)\n", - "\n", - " # Only bottom row gets x-ticks\n", - " if row < len(steps_to_show) - 1:\n", - " ax.tick_params(labelbottom=False)\n", - "\n", - " # ax.legend(loc='best', fontsize=12, title=fr\"$a = {a}, b = {b}$\", handlelength=0, labelspacing=0.1, title_fontsize=14, frameon=False)\n", - " ax.legend(\n", - " loc='center left',\n", - " bbox_to_anchor=(0.95, 0.5), # X slightly beyond the right edge, Y centered\n", - " fontsize=8,\n", - " title=fr\"$a = {a}, b = {b}$\",\n", - " title_fontsize=10,\n", - " handlelength=0,\n", - " labelspacing=0.1,\n", - " frameon=False\n", - " )\n", - "\n", - "# axes[-1].set_xlabel(\"Output Index\", fontsize=20)\n", - "plt.tight_layout()\n", - "plt.savefig(\"predictions.pdf\", bbox_inches='tight')" - ] - }, - { - "cell_type": "markdown", - "id": "b267424b-a0e5-47e3-9e01-1dc41e05e026", - "metadata": {}, - "source": [ - "## Plot Weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9de707bf-838e-4384-8150-3d8fe4586fc3", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "import matplotlib.cm as cm\n", - "import matplotlib.gridspec as gridspec\n", - "\n", - "# Steps and 
corresponding highlighted frequencies\n", - "\n", - "#steps = [1000, 10000, 100000, epochs-1]\n", - "steps = [100, 1000, epochs-1]\n", - "highlight_freqs_list = [[], [1], [3], [5]]\n", - "\n", - "num_rows, num_cols = len(steps), 3\n", - "\n", - "# Use gridspec to control layout\n", - "fig = plt.figure(figsize=(24, 6), constrained_layout=True)\n", - "gs = gridspec.GridSpec(num_rows, num_cols, width_ratios=[1.1, 1.1, 2.0], wspace=0.1, hspace=0.1)\n", - "axes = np.empty((num_rows, num_cols), dtype=object)\n", - "\n", - "# Create axes\n", - "for row in range(num_rows):\n", - " for col in range(num_cols):\n", - " if col == 2:\n", - " ax = fig.add_subplot(gs[row, col], projection='polar')\n", - " else:\n", - " ax = fig.add_subplot(gs[row, col]) # \u2b05 no sharex anymore\n", - " axes[row, col] = ax\n", - "\n", - "num_freqs = None\n", - "for row, index in enumerate(steps):\n", - " highlight_freqs = highlight_freqs_list[row]\n", - " params = param_history[index]\n", - " W = params['W'].numpy()\n", - " h, p = W.shape\n", - "\n", - " if num_freqs is None:\n", - " num_freqs = p // 2 + 1\n", - " cmap = plt.colormaps.get_cmap('tab20').resampled(num_freqs)\n", - " colors = [cmap(i) for i in range(num_freqs)]\n", - " manual_colors = {\n", - " 0: 'tab:blue',\n", - " 1: 'tab:orange',\n", - " 2: 'tab:red',\n", - " 3: 'tab:green',\n", - " 4: 'tab:brown',\n", - " 5: 'tab:purple',\n", - " }\n", - " freq_colors = [manual_colors.get(i, cmap(i)) for i in range(num_freqs)]\n", - "\n", - "\n", - " # === Column 1: Weights ===\n", - " ax = axes[row, 0]\n", - " for i in range(h):\n", - " w = W[i, :]\n", - " ft = np.fft.rfft(w)\n", - " power = np.abs(ft)**2\n", - " dom_idx = np.argmax(power)\n", - " color = freq_colors[dom_idx]\n", - " alpha = 0.9 if not highlight_freqs or dom_idx in highlight_freqs else 0.1\n", - " x = np.linspace(0, p - 1, 500)\n", - " interpolate(ax, w, color=color, continuous=True, alpha=alpha)\n", - " ax.scatter(np.arange(p), w, color=color, s=10, alpha=alpha)\n", - " if row == 0: ax.set_title(\"Weights\", fontsize=24)\n", - " ax.set_ylabel(fr\"$t = 10^{{{int(np.log10(index))}}}$\", fontsize=20)\n", - " style_axes(ax, numyticks=3, numxticks=5, labelsize=12)\n", - " ax.grid(False)\n", - " if row < num_rows - 1:\n", - " ax.tick_params(labelbottom=False)\n", - "\n", - " # === Column 2: Frequency Spectrum ===\n", - " ax = axes[row, 1]\n", - " for i in range(h):\n", - " w = W[i, :]\n", - " ft = np.fft.rfft(w)\n", - " power = np.abs(ft)**2\n", - " for k in range(len(power)):\n", - " color = freq_colors[k]\n", - " ax.vlines(k, 0, power[k], linewidth=4, color=color, alpha=0.4)\n", - " ax.scatter(k, power[k], color=color, s=50, alpha=0.7)\n", - " # ax.axhline(0, color='gray', linewidth=1, linestyle='--', alpha=0.4)\n", - " ax.set_xlim(-0.5, len(power) - 0.5)\n", - " ax.set_xticks(np.arange(len(power)))\n", - " if row == 0: ax.set_title(\"Frequency\", fontsize=24)\n", - " style_axes(ax, numyticks=3, numxticks=11, labelsize=12)\n", - " ax.grid(False)\n", - " if row < num_rows - 1:\n", - " ax.tick_params(labelbottom=False)\n", - "\n", - " # === Column 3: Phase Polar Plot ===\n", - " ax = axes[row, 2]\n", - " for i in range(h):\n", - " w = W[i, :]\n", - " ft = np.fft.rfft(w)\n", - " power = np.abs(ft)**2\n", - " dom_idx = np.argmax(power)\n", - " phase = np.angle(ft[dom_idx])\n", - " norm = np.linalg.norm(w)\n", - " color = freq_colors[dom_idx]\n", - " alpha = 0.9 if not highlight_freqs or dom_idx in highlight_freqs else 0.1\n", - " ax.plot([phase, phase], [0, norm], color=color, linewidth=2, alpha=alpha)\n", 
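- "        # Each ray marks one hidden unit: angle = phase of its dominant\n",
- "        # frequency, radius = weight norm; the marker sits at the ray tip.\n",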
- " ax.scatter(phase, norm, color=color, s=40, alpha=alpha)\n", - " angles = np.arange(0, 360, 45)\n", - " # ax.set_thetagrids(angles, [f\"{a}\u00b0\" if a in [45,135,225,315] else \"\" for a in angles])\n", - " ax.set_thetagrids(angles, [\"\" for a in angles])\n", - " ax.set_yticklabels([])\n", - " ax.spines['polar'].set_linewidth(2)\n", - " if row == 0: ax.set_title(\"Phase\", fontsize=24)\n", - "\n", - "# Shift polar plots left to reduce whitespace\n", - "for row in range(num_rows):\n", - " ax = axes[row, 2]\n", - " pos = ax.get_position()\n", - " ax.set_position([pos.x0 - 0.155, pos.y0, pos.width, pos.height])\n", - "\n", - "plt.savefig(\"W-weights.pdf\", bbox_inches='tight')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9f5f56f9-7055-4056-9a18-7d91b3be50f8", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "group-agf", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file diff --git a/notebooks/dn.ipynb b/notebooks/dn.ipynb new file mode 100644 index 0000000..85d2fd5 --- /dev/null +++ b/notebooks/dn.ipynb @@ -0,0 +1,307 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Binary Group Composition on $D_n$ (Dihedral Group)\n", + "\n", + "**Group:** Dihedral group $D_n$ of order $2n$ (rotations and reflections of a regular $n$-gon). \n", + "**Task:** Given encodings of two group elements $g_1, g_2 \\in D_n$, predict the encoding of their product $g_1 \\cdot g_2$. \n", + "**Sequence length:** $k = 2$ (binary composition). \n", + "**Architecture:** `TwoLayerNet` with square nonlinearity. \n", + "**Key result:** The network learns one irreducible representation at a time, producing a staircase in the training loss." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import random\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "from escnn.group import DihedralGroup\n", + "from torch.utils.data import DataLoader, TensorDataset\n", + "\n", + "import src.dataset as dataset\n", + "import src.model as model\n", + "import src.optimizer as optimizer\n", + "import src.power as power\n", + "import src.template as template\n", + "import src.train as train_mod\n", + "import src.viz as viz" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n", + "\n", + "seed = 0\n", + "random.seed(seed)\n", + "np.random.seed(seed)\n", + "torch.manual_seed(seed)\n", + "\n", + "n = 5 # D_5 has order 2*5 = 10\n", + "group = DihedralGroup(n)\n", + "group_size = group.order()\n", + "\n", + "hidden_size = 20 if TEST_MODE else 180\n", + "epochs = 2 if TEST_MODE else 2000\n", + "lr = 0.01\n", + "init_scale = 1e-3\n", + "\n", + "FIGURES_DIR = \"figures\"\n", + "os.makedirs(FIGURES_DIR, exist_ok=True)\n", + "\n", + "print(f\"Group: D_{n}, order {group_size}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Template and Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Build a template with known Fourier structure on D_n\n", + "# D_5 has 4 irreps: two 1D and two 2D\n", + "# fourier_coef_diag_values: one value per irrep\n", + "fourier_coef_diag_values = [0.0, 5.0, 30.0, 300.0]\n", + "tpl = template.fixed_group(group, fourier_coef_diag_values)\n", + "\n", + "# Build exhaustive dataset: all group_size^2 pairs\n", + "X, Y = dataset.group_dataset(group, tpl)\n", + "X_tensor, Y_tensor, device = dataset.move_dataset_to_device_and_flatten(X, Y)\n", + "\n", + "ds = TensorDataset(X_tensor, Y_tensor)\n", + "dataloader = DataLoader(ds, batch_size=len(ds), shuffle=False)\n", + "\n", + "print(f\"Dataset: {len(ds)} samples (all {group_size}x{group_size} pairs)\")\n", + "print(f\"X shape: {X_tensor.shape}, Y shape: {Y_tensor.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize template and its group power spectrum\n", + "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n", + "\n", + "ax1.bar(range(group_size), tpl, color=\"black\")\n", + "ax1.set_xlabel(\"Group element\")\n", + "ax1.set_ylabel(\"Template value\")\n", + "ax1.set_title(f\"Template $t$ on $D_{{{n}}}$\")\n", + "\n", + "gp = power.GroupPower(tpl, group)\n", + "pwr = gp.group_power_spectrum()\n", + "ax2.bar(range(len(pwr)), pwr, color=\"steelblue\")\n", + "ax2.set_xlabel(\"Irrep index\")\n", + "ax2.set_ylabel(\"Power\")\n", + "ax2.set_title(\"Power spectrum (by irrep)\")\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(f\"{FIGURES_DIR}/dihedral_template.pdf\", bbox_inches=\"tight\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model and Optimizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "net = 
model.TwoLayerNet(\n", + " group_size=group_size,\n", + " hidden_size=hidden_size,\n", + " nonlinearity=\"square\",\n", + " init_scale=init_scale,\n", + ")\n", + "net = net.to(device)\n", + "\n", + "criterion = nn.MSELoss()\n", + "opt = optimizer.PerNeuronScaledSGD(net, lr=lr, degree=3)\n", + "\n", + "print(f\"Model: TwoLayerNet(group_size={group_size}, hidden={hidden_size}, init_scale={init_scale})\")\n", + "print(f\"Optimizer: PerNeuronScaledSGD(lr={lr}, degree=3)\")\n", + "print(f\"Training for {epochs} epochs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loss_history, val_loss_history, param_history, param_save_epochs, final_epoch = train_mod.train(\n", + " net,\n", + " dataloader,\n", + " criterion,\n", + " opt,\n", + " epochs=epochs,\n", + " verbose_interval=max(1, epochs // 10),\n", + " save_param_interval=1,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training Loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Compute theoretical loss plateau levels\n", + "theory_levels = gp.loss_plateau_predictions()\n", + "\n", + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "ax.plot(loss_history, lw=4)\n", + "\n", + "for level in theory_levels:\n", + " ax.axhline(y=level, color=\"black\", linestyle=\"--\", linewidth=2, zorder=-2)\n", + "\n", + "ax.set_xscale(\"log\")\n", + "ax.set_yscale(\"log\")\n", + "ax.set_xlabel(\"Epochs\", fontsize=18)\n", + "ax.set_ylabel(\"Train Loss\", fontsize=18)\n", + "ax.set_title(f\"Training loss on $D_{{{n}}}$\", fontsize=20)\n", + "viz.style_axes(ax)\n", + "ax.grid(False)\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(f\"{FIGURES_DIR}/dihedral_loss.pdf\", bbox_inches=\"tight\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Power Spectrum Over Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use model_power_over_time from src/power.py\n", + "powers_over_time, power_steps = power.model_power_over_time(\n", + " group_name=\"dihedral\",\n", + " model=net,\n", + " param_history=param_history,\n", + " model_inputs=X_tensor,\n", + " group=group,\n", + ")\n", + "\n", + "# Reference: template power per irrep\n", + "template_pwr = gp.group_power_spectrum()\n", + "\n", + "# Plot\n", + "colors = [\"tab:blue\", \"tab:orange\", \"tab:red\", \"tab:green\", \"tab:brown\", \"tab:purple\"]\n", + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "\n", + "n_irreps = powers_over_time.shape[1]\n", + "for k in range(n_irreps):\n", + " color = colors[k] if k < len(colors) else f\"C{k}\"\n", + " ax.plot(power_steps, powers_over_time[:, k], color=color, lw=4, label=rf\"$\\rho_{{{k}}}$\")\n", + " ax.axhline(template_pwr[k], color=color, linestyle=\"dotted\", linewidth=2, alpha=0.5, zorder=-10)\n", + "\n", + "ax.set_xscale(\"log\")\n", + "ax.set_ylabel(\"Power\", fontsize=18)\n", + "ax.set_xlabel(\"Epochs\", fontsize=18)\n", + "ax.set_title(f\"Power spectrum over training on $D_{{{n}}}$\", fontsize=20)\n", + "ax.legend(fontsize=12, title=\"Irrep\", title_fontsize=14, loc=\"upper left\", labelspacing=0.25)\n", + "viz.style_axes(ax)\n", + "ax.grid(False)\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(f\"{FIGURES_DIR}/dihedral_power_spectrum.pdf\", bbox_inches=\"tight\")\n", + 
"plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Irreducible Representations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = viz.plot_irreps(group, show=False)\n", + "plt.savefig(f\"{FIGURES_DIR}/dihedral_irreps.pdf\", bbox_inches=\"tight\")\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "group-agf", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/fourier_power_only.svg b/notebooks/fourier_power_only.svg deleted file mode 100644 index b39f1e3..0000000 --- a/notebooks/fourier_power_only.svg +++ /dev/null @@ -1,1054 +0,0 @@ - - - - - - - - 2026-02-05T16:49:52.817529 - image/svg+xml - - - Matplotlib v3.10.8, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/notebooks/loss-without-lines.svg b/notebooks/loss-without-lines.svg deleted file mode 100644 index e379d4c..0000000 --- a/notebooks/loss-without-lines.svg +++ /dev/null @@ -1,778 +0,0 @@ - - - - - - - - 2026-02-05T16:49:52.258832 - image/svg+xml - - - Matplotlib v3.10.8, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/notebooks/modular_arithmetic.ipynb b/notebooks/modular_arithmetic.ipynb deleted file mode 100644 index 8f4f355..0000000 --- a/notebooks/modular_arithmetic.ipynb +++ /dev/null @@ -1,1063 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "51d11caf-0971-4324-b63b-819b714a9c3c", - "metadata": {}, - "source": [ - "# Modular Addition" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "80f249f1-6985-4c73-86cd-04e1adac3e8d", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import random\n", - 
"import torch\n", - "import torch.nn as nn\n", - "from tqdm import tqdm\n", - "import torch.optim as optim\n", - "from torch.utils.data import DataLoader, TensorDataset\n", - "import seaborn as sns\n", - "import matplotlib.pyplot as plt\n", - "import matplotlib.cm as cm\n", - "from matplotlib.animation import FuncAnimation\n", - "from matplotlib.ticker import FormatStrFormatter\n", - "from matplotlib.ticker import FuncFormatter\n", - "from matplotlib.ticker import MaxNLocator" - ] - }, - { - "cell_type": "markdown", - "id": "9fd05577-db56-4d0a-bb93-1d0b48cecaf6", - "metadata": {}, - "source": [ - "## Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f19bd1ad-9e8f-4720-b317-afe13fafae88", - "metadata": {}, - "outputs": [], - "source": [ - "def one_hot(p):\n", - " \"\"\"One-hot encode an integer value in R^p.\"\"\"\n", - " vec = np.zeros(p)\n", - " vec[0] = 1\n", - " return vec\n", - "\n", - "def generate_template(p, magnitude, exponent):\n", - " weight = magnitude * np.power(np.arange(1, p), -exponent) # Power-law singular values\n", - " template = np.ones(p) # Base term (DC component)\n", - " for freq in range(1, p):\n", - " template += weight[freq-1] * np.cos(np.arange(p) * freq / p * 2 * np.pi)\n", - " return template / p\n", - "\n", - "def generate_fixed_template(p):\n", - " # Generate template array from Fourier spectrum\n", - " spectrum = np.zeros(p, dtype=complex)\n", - " \n", - " # Set only three frequencies with specific amplitudes\n", - " spectrum[1] = 10 # Positive frequency\n", - " spectrum[-1] = 10 # Negative frequency (conjugate)\n", - " spectrum[3] = 5 # Second frequency\n", - " spectrum[-3] = 5 # Its conjugate\n", - " spectrum[5] = 2.5 # Third frequency \n", - " spectrum[-5] = 2.5 # Its conjugate\n", - " \n", - " # Generate signal from spectrum\n", - " template = np.fft.ifft(spectrum).real\n", - "\n", - " return template\n", - "\n", - "def ModularAdditionDataset(p, template):\n", - " # Initialize data arrays\n", - " X = np.zeros((p * p, 2, p)) # Shape: (p^2, 2, p)\n", - " Y = np.zeros((p * p, p)) # Shape: (p^2, p)\n", - " \n", - " # Generate the dataset\n", - " idx = 0\n", - " for a in range(p):\n", - " for b in range(p):\n", - " q = (a + b) % p # a + b mod p\n", - " X[idx, 0, :] = np.roll(template, a)\n", - " X[idx, 1, :] = np.roll(template, b)\n", - " Y[idx, :] = np.roll(template, q)\n", - " idx += 1\n", - " \n", - " return X, Y" - ] - }, - { - "cell_type": "markdown", - "id": "7a0ecbbd-ceaf-4bef-af4a-13a22fa70063", - "metadata": {}, - "source": [ - "## Architecture" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2cf22b7d-49e7-445b-8742-2e75cd1fa55a", - "metadata": {}, - "outputs": [], - "source": [ - "class TwoLayerNet(nn.Module):\n", - " def __init__(self, p, hidden_size, nonlinearity='square', init_scale=1.0, output_scale=1.0):\n", - " super(TwoLayerNet, self).__init__()\n", - " \n", - " # Store dimensions\n", - " self.p = p\n", - " self.hidden_size = hidden_size\n", - " self.nonlinearity = nonlinearity\n", - " self.init_scale = init_scale\n", - " self.output_scale = output_scale\n", - " \n", - " # Initialize parameters \n", - " self.U = nn.Parameter(self.init_scale * torch.randn(hidden_size, p) / np.sqrt(2 * p)) # First p elements\n", - " self.V = nn.Parameter(self.init_scale * torch.randn(hidden_size, p) / np.sqrt(2 * p)) # Second p elements\n", - " self.W = nn.Parameter(self.init_scale * torch.randn(hidden_size, p) / np.sqrt(p)) # Second layer weights\n", - " print(f\"Initialized U with shape 
- "    def forward(self, x):\n",
- "        # First layer: each half of the input has its own weight matrix\n",
- "        x1 = x[:, :self.p] @ self.U.T\n",
- "        x2 = x[:, self.p:] @ self.V.T\n",
- "        x_combined = x1 + x2\n",
- "\n",
- "        # Apply the nonlinearity\n",
- "        if self.nonlinearity == 'relu':\n",
- "            x_combined = torch.relu(x_combined)\n",
- "        elif self.nonlinearity == 'square':\n",
- "            x_combined = x_combined**2\n",
- "        elif self.nonlinearity == 'linear':\n",
- "            pass  # identity\n",
- "        elif self.nonlinearity == 'tanh':\n",
- "            x_combined = torch.tanh(x_combined)\n",
- "        elif self.nonlinearity == 'gelu':\n",
- "            x_combined = torch.nn.functional.gelu(x_combined)\n",
- "        else:\n",
- "            raise ValueError(f\"Invalid nonlinearity '{self.nonlinearity}' provided.\")\n",
- "\n",
- "        # Second layer (linear), then feature-learning scaling\n",
- "        x_out = x_combined @ self.W\n",
- "        x_out *= self.output_scale\n",
- "\n",
- "        return x_out"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "f7e7336b-5c6e-48af-a357-2b2c877f6168",
- "metadata": {},
- "source": [
- "## Optimization"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "1035f81c-e877-4655-8640-4e4c3d323af8",
- "metadata": {},
- "outputs": [],
- "source": [
- "def test_accuracy(model, dataloader):\n",
- "    correct = 0\n",
- "    total = 0\n",
- "\n",
- "    with torch.no_grad():  # Disable gradient calculation for evaluation\n",
- "        for inputs, labels in dataloader:\n",
- "            inputs = inputs.view(inputs.shape[0], -1)  # Flatten input for FC layers\n",
- "            outputs = model(inputs)\n",
- "            _, predicted = torch.max(outputs, 1)  # Index of the largest output value (class)\n",
- "            _, true_labels = torch.max(labels, 1)  # True class from the target encoding\n",
- "            correct += (predicted == true_labels).sum().item()\n",
- "            total += labels.size(0)\n",
- "\n",
- "    return 100 * correct / total\n",
- "\n",
- "def train(model, dataloader, criterion, optimizer, epochs=100, verbose_interval=10):\n",
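- "    # Full-batch protocol: each epoch records the average loss, the argmax\n",
- "    # accuracy over the whole dataset, and a snapshot of (U, V, W).\n",
- "    # Note that snapshots accumulate in memory, one per epoch.\n",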
- "    model.train()  # Set the model to training mode\n",
- "    loss_history = []\n",
- "    accuracy_history = []\n",
- "    param_history = []\n",
- "\n",
- "    for epoch in range(epochs):\n",
- "        running_loss = 0.0\n",
- "        for inputs, labels in dataloader:\n",
- "            inputs = inputs.view(inputs.shape[0], -1)  # Flatten input for FC layers\n",
- "\n",
- "            optimizer.zero_grad()\n",
- "            outputs = model(inputs)  # Forward pass\n",
- "            loss = criterion(outputs, labels)\n",
- "            loss.backward()\n",
- "            optimizer.step()\n",
- "\n",
- "            running_loss += loss.item()\n",
- "\n",
- "        # Average loss for the epoch\n",
- "        avg_loss = running_loss / len(dataloader)\n",
- "        loss_history.append(avg_loss)\n",
- "\n",
- "        # Accuracy over the full dataset\n",
- "        model.eval()\n",
- "        accuracy = test_accuracy(model, dataloader)\n",
- "        accuracy_history.append(accuracy)\n",
- "        model.train()\n",
- "\n",
- "        # Save current model parameters\n",
- "        param_history.append({\n",
- "            \"U\": model.U.detach().cpu().clone(),\n",
- "            \"V\": model.V.detach().cpu().clone(),\n",
- "            \"W\": model.W.detach().cpu().clone()\n",
- "        })\n",
- "\n",
- "        # Print progress every `verbose_interval` epochs\n",
- "        if (epoch + 1) % verbose_interval == 0:\n",
- "            print(f\"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%\")\n",
- "\n",
- "    return loss_history, accuracy_history, param_history"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0e86c4f6-83a6-4465-abf0-7d104432cc9c",
- "metadata": {},
- "source": [
- "## Plotting functions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "014e2d10-9550-4fd4-adb7-168a27fda1b3",
- "metadata": {},
- "outputs": [],
- "source": [
- "def style_axes(ax, numyticks=5, numxticks=5, labelsize=24):\n",
- "    # Y-axis ticks\n",
- "    ax.tick_params(axis=\"y\", which=\"both\", bottom=True, top=False,\n",
- "                   labelbottom=True, left=True, right=False,\n",
- "                   labelleft=True, direction='out', length=7, width=1.5, pad=8, labelsize=labelsize)\n",
- "    ax.yaxis.set_major_locator(MaxNLocator(nbins=numyticks))\n",
- "    \n",
- "    # X-axis ticks\n",
- "    ax.tick_params(axis=\"x\", which=\"both\", bottom=True, top=False,\n",
- "                   labelbottom=True, left=True, right=False,\n",
- "                   labelleft=True, direction='out', length=7, width=1.5, pad=8, labelsize=labelsize)\n",
- "    ax.xaxis.set_major_locator(MaxNLocator(nbins=numxticks))\n",
- "\n",
- "    # Scientific notation formatting\n",
- "    if ax.get_yscale() == 'linear':\n",
- "        ax.ticklabel_format(style='sci', axis='y', scilimits=(-2, 2))\n",
- "    if ax.get_xscale() == 'linear':\n",
- "        ax.ticklabel_format(style='sci', axis='x', scilimits=(-2, 2))\n",
- "\n",
- "    ax.xaxis.offsetText.set_fontsize(20)\n",
- "    ax.grid()\n",
- "\n",
- "    # Customize spines\n",
- "    for spine in [\"top\", \"right\"]:\n",
" ax.spines[spine].set_visible(False)\n", - " for spine in [\"left\", \"bottom\"]:\n", - " ax.spines[spine].set_linewidth(3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "20989d96-f34f-4be7-a0f9-4b92fb7f235a", - "metadata": {}, - "outputs": [], - "source": [ - "def get_power(points):\n", - " p = len(points)\n", - " num_coefficients = (p // 2) + 1\n", - " \n", - " # Perform FFT and calculate power spectrum\n", - " ft = np.fft.fft(points) # Could consider using np.fft.rfft which is designed for real valued input.\n", - " power = np.abs(ft[:num_coefficients])**2 / p\n", - " \n", - " # Double power for frequencies strictly between 0 and Nyquist (Nyquist is not doubled if p is even)\n", - " if p % 2 == 0: # p is even, Nyquist frequency at index num_coefficients - 1\n", - " power[1:num_coefficients - 1] *= 2\n", - " else: # p is odd, no Nyquist frequency\n", - " power[1:] *= 2\n", - "\n", - " # Confirm the power sum approximates the squared norm of points\n", - " total_power = np.sum(power)\n", - " norm_squared = np.linalg.norm(points)**2\n", - " if not np.isclose(total_power, norm_squared, rtol=1e-3):\n", - " print(f\"Warning: Total power {total_power:.3f} does not match norm squared {norm_squared:.3f}\")\n", - "\n", - " return np.arange(num_coefficients), power\n", - "\n", - "def interpolate(ax, points, color, continuous, alpha=1.0):\n", - " p = len(points)\n", - " if continuous:\n", - " # Perform Fourier Transform\n", - " ft = np.fft.fft(points)\n", - " \n", - " # Keep only non-negative frequencies (first half + Nyquist if p is even)\n", - " num_coefficients = (p // 2) + 1\n", - " ft = ft[:num_coefficients] # Truncate to keep non-negative frequencies\n", - " \n", - " # Create a dense set of x-values for smooth interpolation\n", - " xs = np.linspace(0, p, 10 * p) # 10 times more points than the original for smoothness\n", - " curr_val = np.zeros(xs.shape, dtype=complex)\n", - " \n", - " # Use only non-negative frequencies for interpolation\n", - " for freq in range(num_coefficients):\n", - " theta = np.angle(ft[freq])\n", - " r = np.abs(ft[freq]) / p\n", - " # Double amplitude except for DC (freq = 0) and Nyquist (freq = p / 2, when p is even)\n", - " if freq > 0 and (freq < p / 2 or p % 2 != 0):\n", - " r *= 2\n", - " curr_val += r * np.exp(1j * ((2 * np.pi * freq * xs / p) + theta))\n", - "\n", - " # Plot the real part (since output is real-valued)\n", - " ax.plot(xs, curr_val.real, color=color, alpha=alpha)\n", - " else:\n", - " ax.plot(np.arange(p), points, color=color, alpha=alpha) " - ] - }, - { - "cell_type": "markdown", - "id": "e99dae27-f8fe-403a-b70f-0bcaf818cbe7", - "metadata": {}, - "source": [ - "## Gradient Descent Experiment" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bcd15c5a-5745-41ba-b015-48e403160c7e", - "metadata": {}, - "outputs": [], - "source": [ - "seed = 0 # or any integer you like\n", - "random.seed(seed)\n", - "np.random.seed(seed)\n", - "torch.manual_seed(seed)\n", - "torch.cuda.manual_seed_all(seed) # if using GPU\n", - "\n", - "# Data Generation using the new function\n", - "# TEST_MODE: Reduce p and hidden_size for faster automated testing\n", - "import os\n", - "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n", - "p = 20 # Keep same value in TEST_MODE to avoid index errors # Modulus (reduced in test mode)\n", - "\n", - "# Get base vector\n", - "# template = generate_template(p, 2, 1.0)\n", - "# template = one_hot(p)\n", - "template = generate_fixed_template(p)\n", - "\n", - "# Mean 
- "\n",
- "# Mean center template\n",
- "template -= np.mean(template)\n",
- "\n",
- "# Generate dataset using numpy\n",
- "X, Y = ModularAdditionDataset(p, template)\n",
- "\n",
- "# Convert to PyTorch tensors\n",
- "X_tensor = torch.tensor(X, dtype=torch.float32).view(-1, 2 * p)  # Flatten input (num_samples, 2*p)\n",
- "Y_tensor = torch.tensor(Y, dtype=torch.float32)  # Targets (num_samples, p)\n",
- "\n",
- "# Create a TensorDataset and DataLoader\n",
- "dataset = TensorDataset(X_tensor, Y_tensor)\n",
- "dataloader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)\n",
- "# dataloader = DataLoader(dataset, batch_size=32, shuffle=False)\n",
- "\n",
- "# Initialize model\n",
- "hidden_size = 6 if TEST_MODE else 6 * 3  # Reduced in test mode\n",
- "model = TwoLayerNet(p=p, hidden_size=hidden_size, nonlinearity='square', init_scale=1e-2, output_scale=1e0)\n",
- "\n",
- "# Create loss function\n",
- "loss = nn.MSELoss()\n",
- "\n",
- "# Construct optimizer\n",
- "lr, mom = 0.01, 0.9\n",
- "optimizer = optim.SGD(model.parameters(), lr=lr, momentum=mom)\n",
- "# optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))\n",
- "\n",
- "# Train the model\n",
- "# TEST_MODE: Reduce epochs for automated testing (TEST_MODE is set above)\n",
- "epochs = 2 if TEST_MODE else 1000001\n",
- "loss_history, accuracy_history, param_history = train(model, dataloader, loss, optimizer, epochs=epochs, verbose_interval=max(1, epochs//100))"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "eae371c4-1405-4ac5-982c-0ebacb688ed7",
- "metadata": {},
- "source": [
- "## AGF Numerics"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "489e82e1-61c8-43e6-b260-fd96c815dec8",
- "metadata": {},
- "outputs": [],
- "source": [
- "class ModsumSubNetwork(nn.Module):\n",
- "    \n",
- "    def __init__(self, d_in, d_out, init_scale):\n",
- "        super().__init__()\n",
- "        assert d_in%2 == 0\n",
- "        self.p = d_in // 2\n",
- "        self.u = nn.Linear(self.p, 1, bias=False)\n",
- "        self.v = nn.Linear(self.p, 1, bias=False)\n",
- "        self.w = nn.Linear(1, d_out, bias=False)\n",
- "        with torch.no_grad():\n",
- "            self.w.weight.mul_(init_scale)\n",
- "            self.u.weight.mul_(init_scale)\n",
- "            self.v.weight.mul_(init_scale)\n",
- "        self.active = False\n",
- "        self.util_acc = 0\n",
- "        self.c_a = 1/self.get_norm() - 1\n",
- "        \n",
- "        self.normalize()\n",
- "    \n",
- "    def get_norm(self):\n",
- "        sqnorm = lambda x: torch.linalg.norm(x.weight)**2\n",
- "        norm = torch.sqrt(sqnorm(self.w) + sqnorm(self.u) + sqnorm(self.v))\n",
- "        return norm\n",
- "    \n",
- "    def reinitialize(self, u, v, w):\n",
- "        with torch.no_grad():\n",
- "            self.u.weight.copy_(u)\n",
- "            self.v.weight.copy_(v)\n",
- "            self.w.weight.copy_(w)\n",
- "        self.c_a = 1/self.get_norm() - 1\n",
- "    \n",
- "    def forward(self, x):\n",
- "        x1 = x[:, :self.p]\n",
- "        x2 = x[:, self.p:]\n",
- "        return self.w((self.u(x1) + self.v(x2))**2)\n",
- "    \n",
- "    def normalize(self):\n",
- "        norm = self.get_norm()\n",
- "        with torch.no_grad():\n",
- "            self.w.weight.div_(norm)\n",
- "            self.u.weight.div_(norm)\n",
- "            self.v.weight.div_(norm)\n",
- "    \n",
- "    def utility_step(self, x, residual, learning_rate):\n",
- "        f_i = self(x)\n",
- "        util = torch.einsum('nd,nd->n', f_i, residual).mean()\n",
- "        self.util_acc += 3 * learning_rate * util.item()\n",
- "        norm_th = 1/(1 + self.c_a - self.util_acc)\n",
- "        \n",
- "        util.backward()\n",
- "        with torch.no_grad():\n",
- "            self.w.weight += norm_th * learning_rate * 
self.w.weight.grad\n", - " self.u.weight += norm_th * learning_rate * self.u.weight.grad\n", - " self.v.weight += norm_th * learning_rate * self.v.weight.grad\n", - " self.w.weight.grad.zero_()\n", - " self.u.weight.grad.zero_()\n", - " self.v.weight.grad.zero_()\n", - " self.normalize()\n", - "\n", - "\n", - "class ModsumNetwork(nn.Module):\n", - " \n", - " def __init__(self, d_in, d_out, init_scale, width=100):\n", - " super().__init__()\n", - " self.d_in = d_in\n", - " self.d_out = d_out\n", - " self.width = width\n", - " neurons = [ModsumSubNetwork(d_in, d_out, init_scale) for _ in range(width)]\n", - " self.neurons = nn.ModuleList(neurons)\n", - " self.set_mode(\"utilmax\")\n", - " \n", - " def load_init(self, U, V, W):\n", - " for i, n in enumerate(self.neurons):\n", - " u, v, w = U[i], V[i], W[i][:, None]\n", - " n.reinitialize(u, v, w)\n", - "\n", - " def dormant(self):\n", - " return [neuron for neuron in self.neurons if not neuron.active]\n", - " \n", - " def active(self):\n", - " return [neuron for neuron in self.neurons if neuron.active]\n", - "\n", - " \n", - " def set_mode(self, mode):\n", - " if mode not in [\"utilmax\", \"costmin\"]:\n", - " raise ValueError(\"mode must be utilmax or costmin\")\n", - " self.mode = mode\n", - " for neuron in self.neurons:\n", - " grad_on = (mode==\"utilmax\") ^ neuron.active\n", - " for param in neuron.parameters():\n", - " param.requires_grad = grad_on\n", - " \n", - " def forward(self, x):\n", - " if not np.any([n.active for n in self.neurons]):\n", - " return torch.zeros(x.shape[0], self.d_out)\n", - " else:\n", - " outputs = torch.stack([neuron(x) for neuron in self.neurons if neuron.active], dim=0)\n", - " return torch.sum(outputs, dim=0)\n", - "\n", - "\n", - "def train_agf(X_train, Y_train, init_sz=1e-3, agf_steps=5, from_init=None, \n", - " utilmax_lr=1, costmin_lr=1, costmin_maxiter=1e4, loss_thresh=1e-4):\n", - " \n", - " # Initialize\n", - " d_in, d_out = X_train.shape[-1], Y_train.shape[-1]\n", - " if from_init:\n", - " U, V, W = from_init[\"U\"], from_init[\"V\"], from_init[\"W\"]\n", - " assert d_in == U.shape[1]*2\n", - " assert d_out == W.shape[1]\n", - " width = U.shape[0]\n", - " net = ModsumNetwork(d_in, d_out, init_sz, width=width)#.cuda()\n", - " net.load_init(U, V, W)\n", - " else:\n", - " net = ModsumNetwork(d_in, d_out, init_sz, width=agf_steps)#.cuda()\n", - " X_train.requires_grad = False\n", - " \n", - " def update_results(results, t):\n", - " results[\"t\"].append(t)\n", - " residual = (Y_train - net(X_train))\n", - " residual = residual.detach()\n", - " results[\"residuals\"].append(residual)\n", - " loss = (residual**2).mean().item()\n", - " results[\"losses\"].append(loss)\n", - " results[\"models\"].append(net.state_dict())\n", - " results[\"pred\"].append(net(X_train).detach().cpu().clone())\n", - " \n", - " results = {\n", - " \"t\": [],\n", - " \"residuals\": [],\n", - " \"losses\": [],\n", - " \"models\": [],\n", - " \"pred\": [],\n", - " }\n", - "\n", - " t = 0\n", - " update_results(results, t)\n", - " for _ in tqdm(range(agf_steps)):\n", - " \n", - " # Utility Maximization\n", - " residual = (1/d_out) * 2*(Y_train - net(X_train))\n", - " residual = residual.detach()\n", - " iters = 0\n", - " mode = \"utilmax\"\n", - " while mode == \"utilmax\":\n", - " for n in net.neurons:\n", - " if n.active:\n", - " continue\n", - " n.utility_step(X_train, residual, utilmax_lr)\n", - " if n.util_acc > n.c_a:\n", - " n.active = True\n", - " mode = \"costmin\"\n", - " # break\n", - " iters += 1\n", - " 
net.set_mode(mode)\n", - " t += iters\n", - "\n", - " # Cost Minimization\n", - " optimizer = torch.optim.SGD(net.parameters(), lr=costmin_lr, momentum=0.9)\n", - " for i in range(int(costmin_maxiter)):\n", - " optimizer.zero_grad(set_to_none=False)\n", - " residual = Y_train - net(X_train)\n", - " loss = (residual ** 2).mean()\n", - " loss.backward()\n", - " optimizer.step()\n", - " net.set_mode(\"utilmax\")\n", - "\n", - " \n", - " print(f\"loss: {loss.item():.5f}\")\n", - " update_results(results, t)\n", - "\n", - " # Check for Termination\n", - " if not net.dormant() or loss.item() < loss_thresh:\n", - " break\n", - " \n", - " return results" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5b3eddcc-8c3e-45dd-8f57-45585a021f5d", - "metadata": {}, - "outputs": [], - "source": [ - "costmin_lr = 0.01\n", - "utilmax_lr = 0.1\n", - "results = train_agf(X_tensor, Y_tensor, init_sz=model.init_scale, agf_steps=50, from_init=param_history[0],\n", - " utilmax_lr=utilmax_lr, costmin_lr=costmin_lr,\n", - " costmin_maxiter=1e4, loss_thresh=1e-4)" - ] - }, - { - "cell_type": "markdown", - "id": "0f48aebc-a439-405a-a057-3f5c24cca91a", - "metadata": {}, - "source": [ - "## Plot Loss" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ff46febe-abb5-459a-bb06-a18a26afb967", - "metadata": {}, - "outputs": [], - "source": [ - "fig, ax = plt.subplots(1, 1, figsize=(6, 6))\n", - "ax.plot(list(loss_history), lw=4)\n", - "\n", - "for lossval in results[\"losses\"]:\n", - " ax.axhline(lossval, alpha=0.3, ls=\":\", color=\"xkcd:slate\", zorder=-4, lw=2)\n", - "\n", - "f = utilmax_lr / (lr/(1-mom))\n", - "for t in results[\"t\"]:\n", - " ax.axvline(f*t, alpha=0.3, ls=\":\", color=\"xkcd:slate\", zorder=-4, lw=2)\n", - "\n", - "times = results[\"t\"] + [epochs]\n", - "AGF_losses = results[\"losses\"] + [results[\"losses\"][-1]]\n", - "ax.step(f*np.array(times), AGF_losses, where=\"post\", lw=2, ls='dashed', color=\"k\")\n", - "\n", - "# === Compute power spectrum of template ===\n", - "freq, power = get_power(template)\n", - "valid = power > 1e-20\n", - "freq, power = freq[valid], power[valid]\n", - "sorted_idx = np.argsort(-power)\n", - "freq, power = freq[sorted_idx], power[sorted_idx]\n", - "\n", - "alpha_values = [np.sum(power[k:]) for k in range(len(power))]\n", - "coef = 1 / p\n", - "for k, alpha in enumerate(alpha_values):\n", - " ax.axhline(y=coef * alpha, color='black', linestyle='--', linewidth=2, zorder=-2)\n", - "\n", - "ax.set_xscale(\"log\")\n", - "ax.set_yscale(\"log\")\n", - "ax.set_xlim(1e1, 1e6)\n", - "ax.set_ylim(1e-3, 1e0)\n", - "ax.set_xlabel('Epochs', fontsize=24)\n", - "ax.set_ylabel('Train Loss', fontsize=24)\n", - "\n", - "style_axes(ax)\n", - "plt.grid(False)\n", - "plt.tight_layout()\n", - "plt.savefig(\"loss-without-lines.pdf\", bbox_inches=\"tight\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "40b851e7-6256-43cd-b9f3-aca38db04917", - "metadata": {}, - "source": [ - "## Power Spectrum of output" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "68b25ca9-6339-49dd-9d45-577a51798a25", - "metadata": {}, - "outputs": [], - "source": [ - "# === SETTINGS ===\n", - "p = Y_tensor.shape[1]\n", - "num_freqs = p // 2 + 1\n", - "\n", - "# Compute template power spectrum\n", - "template_ft = np.fft.rfft(template)\n", - "template_power = np.abs(template_ft)[:num_freqs]\n", - "\n", - "# === Compute power spectrum of template ===\n", - "freq, power = get_power(template)\n", - "valid = power > 1e-20\n", - "freq, 
power = freq[valid], power[valid]\n", - "sorted_idx = np.argsort(-power)\n", - "freq, power = freq[sorted_idx], power[sorted_idx]\n", - "\n", - "# === Theory lines ===\n", - "alpha_values = [np.sum(power[k:]) for k in range(len(power))]\n", - "coef = 1 / p\n", - "theta0 = np.sqrt(2) * model.init_scale\n", - "uMax = [np.sqrt(2 * p / 27) * (p * power[k] / 2)**(3/2) / p**2 for k in range(len(power))]\n", - "tau_values = [(1 / theta0 - 1) / (3 * uMax[k]) for k in range(len(uMax))]\n", - "step_size = 2 * coef * lr / (1 - mom)\n", - "\n", - "\n", - "# Color settings\n", - "cmap = plt.colormaps.get_cmap('tab20').resampled(num_freqs)\n", - "manual_colors = {\n", - " 0: 'tab:blue',\n", - " 1: 'tab:orange',\n", - " 2: 'tab:red',\n", - " 3: 'tab:green',\n", - " 4: 'tab:brown',\n", - " 5: 'tab:purple',\n", - "}\n", - "colors = [manual_colors.get(i, cmap(i)) for i in range(num_freqs)]\n", - "\n", - "# Compute output power over time (GD)\n", - "num_points = 1000\n", - "steps = np.unique(np.logspace(0, np.log10(len(param_history) - 1), num_points, dtype=int))\n", - "powers_over_time = []\n", - "\n", - "for step in steps:\n", - " model.load_state_dict(param_history[step])\n", - " model.eval()\n", - " with torch.no_grad():\n", - " outputs = model(X_tensor)\n", - " ft = np.fft.rfft(outputs.detach().cpu().numpy(), axis=1)\n", - " avg_power = np.mean(np.abs(ft), axis=0)\n", - " powers_over_time.append(avg_power)\n", - "\n", - "powers_over_time = np.array(powers_over_time) # shape: (steps, freqs)\n", - "\n", - "\n", - "# Compute output power over time (AGF)\n", - "f = utilmax_lr / (lr/(1-mom))\n", - "AGF_steps = results[\"t\"]\n", - "powers_over_time_AGF = []\n", - "for i, step in enumerate(AGF_steps):\n", - " outputs = results[\"pred\"][i]\n", - " ft = np.fft.rfft(outputs.detach().cpu().numpy(), axis=1)\n", - " avg_power = np.mean(np.abs(ft), axis=0)\n", - " powers_over_time_AGF.append(avg_power)\n", - "powers_over_time_AGF = np.array(powers_over_time_AGF) # shape: (steps, freqs)\n", - "AGF_steps = [f * t for t in AGF_steps]\n", - "\n", - "AGF_steps.append(epochs)\n", - "powers_over_time_AGF = np.vstack([\n", - " powers_over_time_AGF,\n", - " powers_over_time_AGF[-1, :]\n", - "])\n", - "\n", - "# === PLOTTING ===\n", - "fig, ax = plt.subplots(figsize=(6, 7))\n", - "\n", - "for k in range(num_freqs):\n", - " color = colors[k]\n", - " label = fr\"$\\xi = {k}$\" if k in [1, 3, 5] else None\n", - " ax.plot(steps, powers_over_time[:, k], color=color, lw=3, label=label)\n", - " label_agf = 'AGF' if k == 10 else None\n", - " ax.step(AGF_steps, powers_over_time_AGF[:, k], color='k', lw=2, ls='dashed', where=\"post\", label=label_agf)\n", - " ax.axhline(template_power[k], color=color, linestyle='dotted', linewidth=2, alpha=0.5, zorder=-10)\n", - "\n", - "for k, tau in enumerate(tau_values):\n", - " color = colors[freq[k]]\n", - " ax.axvline(x=tau / step_size, color=color, linestyle='dashed', linewidth=2, alpha=0.5)\n", - "\n", - " # Add arrow at intersection\n", - " x = tau / step_size\n", - " y = template_power[freq[k]]\n", - " #draw an arrow from the lower bound to the right\n", - " #use default color cycle\n", - " ax.arrow(1.04 * x, y + 0.5, 1.5 * x, 0, \n", - " head_width=0.2, head_length=x*0.2, length_includes_head=True,\n", - " fc=color, ec=color, lw=4)\n", - "\n", - "# # Add vertical lines if needed\n", - "# for step in time_steps:\n", - "# ax.axvline(x=step, color='gray', alpha=0.5, linestyle='solid', linewidth=2)\n", - "\n", - "# Labeling and formatting\n", - "ax.set_xscale('log')\n", - "ax.set_xlim(5e1, 
2e6)\n", - "ax.set_xticks([1000, 10000, 100000, epochs-1])\n", - "ax.set_ylabel(\"Power\", fontsize=24)\n", - "ax.set_xlabel(\"Epochs\", fontsize=24)\n", - "ax.legend(fontsize=14, title=\"Frequency\", title_fontsize=16, loc='upper right', bbox_to_anchor=(1, 0.9), labelspacing=0.25)\n", - "\n", - "style_axes(ax)\n", - "ax.set_xticks([1000, 10000, 100000, epochs-1])\n", - "ax.grid(False)\n", - "plt.tight_layout()\n", - "plt.savefig(\"fourier_power_only.pdf\", bbox_inches=\"tight\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "5ef2c971-d9f1-41e6-b8eb-4e467496ccfd", - "metadata": {}, - "source": [ - "## Plot outputs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e333d1ab-1501-434f-86d2-82c10bb58f11", - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "# Choose time steps to visualize\n", - "steps_to_show = [1000, 10000, 100000, epochs-1]\n", - "num_samples = 1 # how many examples to plot per row\n", - "p = Y_tensor.shape[1]\n", - "x = np.arange(p)\n", - "\n", - "fig, axes = plt.subplots(len(steps_to_show), 1, figsize=(6, 6), sharex=True)\n", - "\n", - "for row, step in enumerate(steps_to_show):\n", - " # Load weights at this step\n", - " model.load_state_dict(param_history[step])\n", - " model.eval()\n", - "\n", - " indices = np.random.choice(len(Y_tensor), size=num_samples, replace=False)\n", - " with torch.no_grad():\n", - " preds = model(X_tensor[indices]).detach().cpu().numpy()\n", - " truths = Y_tensor[indices].detach().cpu().numpy()\n", - "\n", - " ax = axes[row]\n", - " for i, idx in enumerate(indices):\n", - " a = idx // p\n", - " b = idx % p\n", - " label_true = r\"$(a + b) \\cdot x$\"\n", - " label_pred = r\"$f(a \\cdot x, b \\cdot x)$\"\n", - "\n", - " # Plot ground truth\n", - " interpolate(ax, truths[i], color=f\"C{i}\", alpha=0.9, continuous=True)\n", - " ax.scatter(x, truths[i], color=f\"C{i}\", s=30, alpha=0.9, label=label_true)\n", - "\n", - " # Plot prediction\n", - " interpolate(ax, preds[i], color='k', alpha=1.0, continuous=True)\n", - " ax.scatter(x, preds[i], color='k', s=30, alpha=0.7, label=label_pred)\n", - "\n", - " style_axes(ax, numyticks=3, labelsize=12)\n", - " ax.grid(False)\n", - " ax.set_ylabel(fr\"$t = 10^{{{int(np.log10(step))}}}$\", fontsize=20)\n", - "\n", - " # Only bottom row gets x-ticks\n", - " if row < len(steps_to_show) - 1:\n", - " ax.tick_params(labelbottom=False)\n", - "\n", - " # ax.legend(loc='best', fontsize=12, title=fr\"$a = {a}, b = {b}$\", handlelength=0, labelspacing=0.1, title_fontsize=14, frameon=False)\n", - " ax.legend(\n", - " loc='center left',\n", - " bbox_to_anchor=(0.95, 0.5), # X slightly beyond the right edge, Y centered\n", - " fontsize=8,\n", - " title=fr\"$a = {a}, b = {b}$\",\n", - " title_fontsize=10,\n", - " handlelength=0,\n", - " labelspacing=0.1,\n", - " frameon=False\n", - " )\n", - "\n", - "# axes[-1].set_xlabel(\"Output Index\", fontsize=20)\n", - "plt.tight_layout()\n", - "plt.savefig(\"predictions.pdf\", bbox_inches='tight')" - ] - }, - { - "cell_type": "markdown", - "id": "b267424b-a0e5-47e3-9e01-1dc41e05e026", - "metadata": {}, - "source": [ - "## Plot Weights" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9de707bf-838e-4384-8150-3d8fe4586fc3", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "import matplotlib.cm as cm\n", - "import matplotlib.gridspec as gridspec\n", - "\n", - "# Steps and corresponding highlighted 
frequencies\n", - "\n", - "steps = [1000, 10000, 100000, epochs-1]\n", - "highlight_freqs_list = [[], [1], [3], [5]]\n", - "\n", - "num_rows, num_cols = len(steps), 3\n", - "\n", - "# Use gridspec to control layout\n", - "fig = plt.figure(figsize=(24, 6), constrained_layout=True)\n", - "gs = gridspec.GridSpec(num_rows, num_cols, width_ratios=[1.1, 1.1, 2.0], wspace=0.1, hspace=0.1)\n", - "axes = np.empty((num_rows, num_cols), dtype=object)\n", - "\n", - "# Create axes\n", - "for row in range(num_rows):\n", - " for col in range(num_cols):\n", - " if col == 2:\n", - " ax = fig.add_subplot(gs[row, col], projection='polar')\n", - " else:\n", - " ax = fig.add_subplot(gs[row, col]) # ⬅ no sharex anymore\n", - " axes[row, col] = ax\n", - "\n", - "num_freqs = None\n", - "for row, index in enumerate(steps):\n", - " highlight_freqs = highlight_freqs_list[row]\n", - " params = param_history[index]\n", - " W = params['W'].numpy()\n", - " h, p = W.shape\n", - "\n", - " if num_freqs is None:\n", - " num_freqs = p // 2 + 1\n", - " cmap = plt.colormaps.get_cmap('tab20').resampled(num_freqs)\n", - " colors = [cmap(i) for i in range(num_freqs)]\n", - " manual_colors = {\n", - " 0: 'tab:blue',\n", - " 1: 'tab:orange',\n", - " 2: 'tab:red',\n", - " 3: 'tab:green',\n", - " 4: 'tab:brown',\n", - " 5: 'tab:purple',\n", - " }\n", - " freq_colors = [manual_colors.get(i, cmap(i)) for i in range(num_freqs)]\n", - "\n", - "\n", - " # === Column 1: Weights ===\n", - " ax = axes[row, 0]\n", - " for i in range(h):\n", - " w = W[i, :]\n", - " ft = np.fft.rfft(w)\n", - " power = np.abs(ft)**2\n", - " dom_idx = np.argmax(power)\n", - " color = freq_colors[dom_idx]\n", - " alpha = 0.9 if not highlight_freqs or dom_idx in highlight_freqs else 0.1\n", - " x = np.linspace(0, p - 1, 500)\n", - " interpolate(ax, w, color=color, continuous=True, alpha=alpha)\n", - " ax.scatter(np.arange(p), w, color=color, s=10, alpha=alpha)\n", - " if row == 0: ax.set_title(\"Weights\", fontsize=24)\n", - " ax.set_ylabel(fr\"$t = 10^{{{int(np.log10(index))}}}$\", fontsize=20)\n", - " style_axes(ax, numyticks=3, numxticks=5, labelsize=12)\n", - " ax.grid(False)\n", - " if row < num_rows - 1:\n", - " ax.tick_params(labelbottom=False)\n", - "\n", - " # === Column 2: Frequency Spectrum ===\n", - " ax = axes[row, 1]\n", - " for i in range(h):\n", - " w = W[i, :]\n", - " ft = np.fft.rfft(w)\n", - " power = np.abs(ft)**2\n", - " for k in range(len(power)):\n", - " color = freq_colors[k]\n", - " ax.vlines(k, 0, power[k], linewidth=4, color=color, alpha=0.4)\n", - " ax.scatter(k, power[k], color=color, s=50, alpha=0.7)\n", - " # ax.axhline(0, color='gray', linewidth=1, linestyle='--', alpha=0.4)\n", - " ax.set_xlim(-0.5, len(power) - 0.5)\n", - " ax.set_xticks(np.arange(len(power)))\n", - " if row == 0: ax.set_title(\"Frequency\", fontsize=24)\n", - " style_axes(ax, numyticks=3, numxticks=11, labelsize=12)\n", - " ax.grid(False)\n", - " if row < num_rows - 1:\n", - " ax.tick_params(labelbottom=False)\n", - "\n", - " # === Column 3: Phase Polar Plot ===\n", - " ax = axes[row, 2]\n", - " for i in range(h):\n", - " w = W[i, :]\n", - " ft = np.fft.rfft(w)\n", - " power = np.abs(ft)**2\n", - " dom_idx = np.argmax(power)\n", - " phase = np.angle(ft[dom_idx])\n", - " norm = np.linalg.norm(w)\n", - " color = freq_colors[dom_idx]\n", - " alpha = 0.9 if not highlight_freqs or dom_idx in highlight_freqs else 0.1\n", - " ax.plot([phase, phase], [0, norm], color=color, linewidth=2, alpha=alpha)\n", - " ax.scatter(phase, norm, color=color, s=40, alpha=alpha)\n", - " 
angles = np.arange(0, 360, 45)\n", - "    # ax.set_thetagrids(angles, [f\"{a}°\" if a in [45,135,225,315] else \"\" for a in angles])\n", - "    ax.set_thetagrids(angles, [\"\" for a in angles])\n", - "    ax.set_yticklabels([])\n", - "    ax.spines['polar'].set_linewidth(2)\n", - "    if row == 0: ax.set_title(\"Phase\", fontsize=24)\n", - "\n", - "# Shift polar plots left to reduce whitespace\n", - "for row in range(num_rows):\n", - "    ax = axes[row, 2]\n", - "    pos = ax.get_position()\n", - "    ax.set_position([pos.x0 - 0.155, pos.y0, pos.width, pos.height])\n", - "\n", - "plt.savefig(\"W-weights.pdf\", bbox_inches='tight')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9f5f56f9-7055-4056-9a18-7d91b3be50f8", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "rubiks", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/paper_figures.ipynb b/notebooks/paper_figures.ipynb deleted file mode 100644 index 696c450..0000000 --- a/notebooks/paper_figures.ipynb +++ /dev/null @@ -1,268 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "249dbef8", - "metadata": {}, - "source": [ - "# Create Figures for Group AGF Paper" - ] - }, - { - "cell_type": "markdown", - "id": "58a60e4e", - "metadata": {}, - "source": [ - "For each group, we should have 4 figures that make up a panel. We can organize these figures in Adobe Illustrator for the final figures.\n", - "- Loss plot with clear drops\n", - "- Model output power over epoch\n", - "- Weights over epoch (ideally one weight per irrep, but for MNIST just choose ~5)\n", - "- Output prediction vs. target over epoch" - ] - }, - { - "cell_type": "markdown", - "id": "3fd72af4", - "metadata": {}, - "source": [ - "Action Plan:\n", - "- Load model\n", - "- Recreate power and loss figures\n", - "\n", - "Weights figure\n", - "- Go through the weights and label each with its predominant frequency\n", - "- Choose 1 weight per frequency, and plot these over time. In Adobe Illustrator, outline these weights with the same color as their power\n", - "- In the loss plot, try to make the color of the line (starting at the drop) match the color of the power and the irreps\n", - "\n", - "Output figure (see the sketch below)\n", - "- Detect jumps in power\n", - "- Right after each jump, save the epoch value for the label\n", - "- At that epoch, plot Y[i] and the model(X[i]) output.\n",
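- "\n", - "A minimal sketch of the jump-detection step (a hypothetical helper, not code from the repo; it assumes a 1D array of per-epoch power for a single frequency, e.g. one column of powers_over_time from the 2D notebook):\n", - "\n", - "```python\n", - "import numpy as np\n", - "\n", - "def detect_power_jumps(power, rel_thresh=2.0):\n", - "    # Return the epochs i at which power[i] > rel_thresh * power[i - 1],\n", - "    # i.e. the first epoch after each jump.\n", - "    log_p = np.log10(np.maximum(power, 1e-20))  # guard against zero power\n", - "    return np.where(np.diff(log_p) > np.log10(rel_thresh))[0] + 1\n", - "```\n", - "\n", - "Each returned epoch can then be used to index param_history and to plot Y[i] against model(X[i]) right after the corresponding drop."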
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7ff57238", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import random\n", - "import torch\n", - "import os\n", - "import torch.nn as nn\n", - "import torch.optim as optim\n", - "import shutil\n", - "from torch.utils.data import DataLoader, TensorDataset\n", - "import matplotlib.pyplot as plt\n", - "import matplotlib.cm as cm\n", - "from matplotlib.animation import FuncAnimation\n", - "from matplotlib.ticker import FormatStrFormatter\n", - "from matplotlib.ticker import FuncFormatter\n", - "from matplotlib.ticker import MaxNLocator\n", - "\n", - "import importlib\n", - "import pickle\n", - "\n", - "import group_agf.binary_action_learning.models as models\n", - "import group_agf.binary_action_learning.datasets as datasets\n", - "import group_agf.binary_action_learning.power as power\n", - "import group_agf.binary_action_learning.plot as plot\n", - "\n", - "from escnn.group import *" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d9fef4df", - "metadata": {}, - "outputs": [], - "source": [ - "# parameters\n", - "config = {\n", - "    \"model_save_dir\": \"/tmp/adele/\",\n", - "    \"dataset_fraction\": 0.3,  # fraction of the total dataset to train on\n", - "    \"group_name\": 'dihedral',  # 'dihedral', 'cnxcn', 'octahedral'\n", - "    \"mnist_digit\": 4,\n", - "    \"group_n\": 4,  # n in Dn [3, 4, 5]\n", - "    \"image_length\": 10,  # length of one side of the square image patch\n", - "    # Learning parameters\n", - "    \"seed\": 10,\n", - "    \"init_scale\": 1e-2,\n", - "    \"lr\": 0.001,\n", - "    \"mom\": 0.9,\n", - "    # Training parameters\n", - "    \"epochs\": 5000,\n", - "    \"verbose_interval\": 100,\n", - "    \"batch_size\": 128,\n", - "    \"frequencies_to_learn\": 3,\n", - "    \"run_start_time\": \"10-31_14-42-46\",\n", - "}\n", - "\n", - "config[\"group\"] = DihedralGroup(config[\"group_n\"])\n", - "config[\"group_size\"] = config[\"group\"].order()\n", - "\n", - "print(f\"Group name: {config['group_name']}, group size: {config['group_size']}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aa742e82", - "metadata": {}, - "outputs": [], - "source": [ - "model_save_path = models.get_model_save_path(config)\n", - "print(model_save_path)\n", - "# model_save_path = '/tmp/adele/model_group_namedihedral_group_size8_frac0.3_init0.01_lr0.01_mom0.9_bs128_epochs5000_seed10_run_start10-31_14-25-32.pkl'\n", - "# print(model_save_path)\n", - "\n", - "# Load training history and model parameters from the specified path\n", - "with open(model_save_path, 'rb') as f:\n", - "    training_history = pickle.load(f)\n", - "\n", - "loss_history = training_history[\"loss_history\"]\n", - "accuracy_history = training_history[\"accuracy_history\"]\n", - "param_history = training_history[\"param_history\"]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "14393dea", - "metadata": {}, - "outputs": [], - "source": [ - "X, Y, template = datasets.load_dataset(config)\n", - "template_power = power.GroupPower(template, config['group'])\n", - "X, Y, device = datasets.move_dataset_to_device_and_flatten(X, Y, device=None)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "78975e86", - "metadata": {}, - "outputs": [], - "source": [ - "# Create the model instance (matching the one that was trained)\n", - "model = models.TwoLayerNet(\n", - "    group_size=config['group_size'],\n", - "    hidden_size=None, # uses default if 
not provided\n", - " nonlinearity='square', # adjust if needed according to training\n", - " init_scale=config['init_scale'],\n", - " output_scale=1.0\n", - ").to(device)\n", - "\n", - "# Restore the model parameters from the last saved point in param_history\n", - "final_params = param_history[-1]\n", - "with torch.no_grad():\n", - " model.U.copy_(final_params['U'])\n", - " model.V.copy_(final_params['V'])\n", - " model.W.copy_(final_params['W'])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a1723c01", - "metadata": {}, - "outputs": [], - "source": [ - "power_over_training_plot, freq_colors = plot.plot_training_power_over_time(\n", - " template_power, \n", - " model, \n", - " device, \n", - " param_history, \n", - " X, \n", - " config['group_name'], \n", - " save_path=None, \n", - " show=False,\n", - " return_freq_colors=True\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "89459141", - "metadata": {}, - "outputs": [], - "source": [ - "loss_plot = plot.plot_loss_curve(loss_history, template_power, show=False, freq_colors=freq_colors)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "86b2dd11", - "metadata": {}, - "outputs": [], - "source": [ - "irreps_plot = plot.plot_irreps(config['group'], config['group_name'], show=False)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9d8be8b8", - "metadata": {}, - "outputs": [], - "source": [ - "neuron_indices = list(range(config['group_size'] * config['frequencies_to_learn']))\n", - "neuron_weights_plot = plot.plot_neuron_weights(\n", - " config['group_name'],\n", - " config['group'],\n", - " model,\n", - " config['group_size'],\n", - " neuron_indices=neuron_indices,\n", - " show=True\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "31eec909", - "metadata": {}, - "outputs": [], - "source": [ - "model_predictions_plot = plot.plot_model_outputs(config['group_name'], config['group_size'], model, X, Y, idx=13, step=1, show=True) \n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3a877866", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "group-agf", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/rnn_gagf.ipynb b/notebooks/rnn_gagf.ipynb deleted file mode 100644 index fb9f402..0000000 --- a/notebooks/rnn_gagf.ipynb +++ /dev/null @@ -1,2112 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "af291059", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import yaml\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from pathlib import Path\n", - "import seaborn as sns\n", - "\n", - "def load_sweep_results_grid(sweep_dir: str, k_values: list, hidden_dims: list):\n", - " \"\"\"\n", - " Load sweep results and organize into a grid for heatmap visualization.\n", - " \n", - " Args:\n", - " sweep_dir: Path to the sweep directory\n", - " k_values: List of k (sequence length) values\n", - " hidden_dims: List of hidden dimension values\n", - " \n", - " Returns:\n", - " grid: 2D numpy array with shape (len(hidden_dims), len(k_values))\n", - " 
containing mean final train losses\n", - " std_grid: 2D numpy array with standard deviations (if multiple seeds)\n", - " \"\"\"\n", - " sweep_path = Path(sweep_dir)\n", - " \n", - " # Initialize grids\n", - " grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n", - " std_grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n", - " \n", - " # Load results for each experiment\n", - " for i, h in enumerate(hidden_dims):\n", - " for j, k in enumerate(k_values):\n", - " exp_name = f\"k{k}_h{h}\"\n", - " exp_dir = sweep_path / exp_name\n", - " \n", - " if not exp_dir.exists():\n", - " print(f\"Warning: Experiment {exp_name} not found\")\n", - " continue\n", - " \n", - " # Load experiment summary\n", - " summary_file = exp_dir / \"experiment_summary.yaml\"\n", - " if summary_file.exists():\n", - " with open(summary_file, 'r') as f:\n", - " summary = yaml.safe_load(f)\n", - " \n", - " # Get mean train loss\n", - " if 'train_loss_stats' in summary:\n", - " grid[i, j] = summary['train_loss_stats']['mean']\n", - " std_grid[i, j] = summary['train_loss_stats']['std']\n", - " else:\n", - " print(f\"Warning: No train_loss_stats in {exp_name}\")\n", - " else:\n", - " print(f\"Warning: No summary file for {exp_name}\")\n", - " \n", - " return grid, std_grid\n" - ] - }, - { - "cell_type": "markdown", - "id": "c433cb4d", - "metadata": {}, - "source": [ - "## 1D Analysis Functions\n", - "\n", - "Analyze individual 1D experiments from the sweep with detailed power spectrum and neuron specialization plots.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "95df6861", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import numpy as np\n", - "from pathlib import Path\n", - "from gagf.rnns.utils import (\n", - " plot_prediction_power_spectrum_over_time_1d,\n", - " plot_wout_neuron_specialization_1d,\n", - " plot_model_predictions_over_time_1d,\n", - " topk_template_freqs_1d,\n", - ")\n", - "from gagf.rnns.model import SequentialMLP\n", - "\n", - "def analyze_1d_experiment(sweep_dir, exp_name, seed=0, num_freqs_to_track=10):\n", - " \"\"\"\n", - " Analyze a single 1D experiment from the sweep.\n", - " \n", - " Args:\n", - " sweep_dir: Path to sweep directory\n", - " exp_name: Experiment name (e.g., \"k3_h360\")\n", - " seed: Seed number to analyze\n", - " num_freqs_to_track: Number of top frequencies to track\n", - " \n", - " Returns:\n", - " Dictionary with analysis results\n", - " \"\"\"\n", - " # Setup paths\n", - " sweep_path = Path(sweep_dir)\n", - " exp_dir = sweep_path / exp_name / f\"seed_{seed}\"\n", - " \n", - " if not exp_dir.exists():\n", - " print(f\"Experiment directory not found: {exp_dir}\")\n", - " return None\n", - " \n", - " print(f\"Analyzing: {exp_name}, seed {seed}\")\n", - " print(f\"Directory: {exp_dir}\")\n", - " \n", - " # Load config\n", - " with open(exp_dir / \"config.yaml\", 'r') as f:\n", - " config = yaml.safe_load(f)\n", - " \n", - " # Load template\n", - " template = np.load(exp_dir / \"template.npy\")\n", - " p = len(template)\n", - " k = config['data']['k']\n", - " hidden_dim = config['model']['hidden_dim']\n", - " \n", - " print(f\" p={p}, k={k}, hidden_dim={hidden_dim}\")\n", - " \n", - " # Load training history\n", - " train_loss_hist = np.load(exp_dir / \"train_loss_history.npy\")\n", - " param_hist = torch.load(exp_dir / \"param_history.pt\", map_location='cpu')\n", - " \n", - " # Create model\n", - " device = 'cpu'\n", - " template_torch = torch.tensor(template, dtype=torch.float32, device=device)\n", - " model = 
SequentialMLP(\n", - " p=p,\n", - " d=hidden_dim,\n", - " template=template_torch,\n", - " k=k,\n", - " init_scale=config['model']['init_scale'],\n", - " return_all_outputs=config['model']['return_all_outputs'],\n", - " ).to(device)\n", - " \n", - " # Generate evaluation data\n", - " from gagf.rnns.datamodule import build_modular_addition_sequence_dataset_1d\n", - " X_data, Y_data, _ = build_modular_addition_sequence_dataset_1d(\n", - " p, template, k,\n", - " mode='sampled',\n", - " num_samples=1000,\n", - " return_all_outputs=config['model']['return_all_outputs'],\n", - " )\n", - " X_data_t = torch.tensor(X_data, dtype=torch.float32, device=device)\n", - " Y_data_t = torch.tensor(Y_data, dtype=torch.float32, device=device)\n", - " \n", - " # Get tracked frequencies\n", - " tracked_freqs = topk_template_freqs_1d(template, K=num_freqs_to_track)\n", - " colors = plt.cm.tab10(np.linspace(0, 1, len(tracked_freqs)))\n", - " \n", - " # Checkpoints to analyze\n", - " checkpoint_indices = [0, len(param_hist)//4, len(param_hist)//2, \n", - " 3*len(param_hist)//4, len(param_hist)-1]\n", - " \n", - " # Plot 1: Power spectrum over time\n", - " print(\"\\\\n Plotting power spectrum analysis...\")\n", - " fig1, _, _, _ = plot_prediction_power_spectrum_over_time_1d(\n", - " model, param_hist, X_data_t, Y_data_t, template, p,\n", - " loss_history=train_loss_hist,\n", - " num_freqs_to_track=num_freqs_to_track,\n", - " num_samples=100,\n", - " save_path=exp_dir / \"power_spectrum_analysis_1d.pdf\",\n", - " show=True\n", - " )\n", - " \n", - " # Plot 2: Model predictions over time\n", - " print(\" Plotting predictions over time...\")\n", - " fig2, _ = plot_model_predictions_over_time_1d(\n", - " model, param_hist, X_data_t, Y_data_t, p,\n", - " steps=checkpoint_indices,\n", - " save_path=exp_dir / \"predictions_over_time_1d.pdf\",\n", - " show=True\n", - " )\n", - " \n", - " # Plot 3: W_out neuron specialization\n", - " print(\" Plotting W_out neuron specialization...\")\n", - " figs3 = plot_wout_neuron_specialization_1d(\n", - " param_hist, tracked_freqs, colors, p,\n", - " steps=checkpoint_indices,\n", - " dead_thresh_l2=0.25,\n", - " save_dir=exp_dir,\n", - " show=True\n", - " )\n", - " \n", - " print(\"\\\\n ✓ Analysis complete!\")\n", - " \n", - " return {\n", - " 'config': config,\n", - " 'template': template,\n", - " 'train_loss': train_loss_hist,\n", - " 'tracked_freqs': tracked_freqs,\n", - " }\n", - "\n", - "# Example usage:\n", - "# sweep_dir = \"/home/facosta/group-agf/sweeps/sweep_mlp_scaling_20251202_XXXXXX\"\n", - "# result = analyze_1d_experiment(sweep_dir, \"k3_h360\", seed=0)\n" - ] - }, - { - "cell_type": "markdown", - "id": "7bc3db90", - "metadata": {}, - "source": [ - "# Analyze RNNs trained on GAGF sequential task" - ] - }, - { - "cell_type": "markdown", - "id": "11fb7c9b", - "metadata": {}, - "source": [ - "## Set up" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "43581ce4", - "metadata": {}, - "outputs": [], - "source": [ - "# autoreload\n", - "%load_ext autoreload\n", - "%autoreload 2\n", - "# jupyter black formatter\n", - "%load_ext jupyter_black\n", - "\n", - "import subprocess\n", - "import os\n", - "import sys\n", - "\n", - "gitroot_path = subprocess.check_output(\n", - " [\"git\", \"rev-parse\", \"--show-toplevel\"], universal_newlines=True\n", - ").strip()\n", - "\n", - "os.chdir(gitroot_path)\n", - "print(\"Working directory: \", os.getcwd())\n", - "\n", - "if gitroot_path not in sys.path:\n", - " sys.path.insert(0, gitroot_path)\n", - 
"print(\"Directory added to path: \", gitroot_path)" - ] - }, - { - "cell_type": "markdown", - "id": "f0407a17", - "metadata": {}, - "source": [ - "## Sequence-to-sequence sweep across different values of k (sequence length)" - ] - }, - { - "cell_type": "markdown", - "id": "070e8c55", - "metadata": {}, - "source": [ - "### Loss curves" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cf050abb", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from pathlib import Path\n", - "from typing import Dict, List, Optional\n", - "\n", - "\n", - "def get_sweep_experiments(sweep_dir: str) -> List[str]:\n", - " \"\"\"\n", - " Get all experiment names from a sweep directory.\n", - "\n", - " Args:\n", - " sweep_dir: Path to sweep directory\n", - "\n", - " Returns:\n", - " List of experiment names (subdirectories with seed_0)\n", - " \"\"\"\n", - " sweep_path = Path(sweep_dir)\n", - " experiments = []\n", - "\n", - " for item in sweep_path.iterdir():\n", - " if (\n", - " item.is_dir()\n", - " and not item.name.startswith(\".\")\n", - " and item.name not in [\"configs\"]\n", - " ):\n", - " # Check if it has a seed_0 subdirectory\n", - " if (item / \"seed_0\").exists():\n", - " experiments.append(item.name)\n", - "\n", - " return sorted(experiments)\n", - "\n", - "\n", - "def load_experiment_losses(\n", - " sweep_dir: str, experiment_name: str, seed: int = 0\n", - ") -> Dict[str, np.ndarray]:\n", - " \"\"\"\n", - " Load training and validation loss histories for an experiment.\n", - "\n", - " Args:\n", - " sweep_dir: Path to sweep directory\n", - " experiment_name: Name of the experiment subdirectory\n", - " seed: Seed number (default: 0)\n", - "\n", - " Returns:\n", - " Dictionary with 'train' and 'val' loss arrays (if they exist)\n", - " \"\"\"\n", - " exp_path = Path(sweep_dir) / experiment_name / f\"seed_{seed}\"\n", - " losses = {}\n", - "\n", - " train_loss_path = exp_path / \"train_loss_history.npy\"\n", - " if train_loss_path.exists():\n", - " losses[\"train\"] = np.load(train_loss_path)\n", - "\n", - " val_loss_path = exp_path / \"val_loss_history.npy\"\n", - " if val_loss_path.exists():\n", - " losses[\"val\"] = np.load(val_loss_path)\n", - "\n", - " return losses\n", - "\n", - "\n", - "def load_all_sweep_losses(\n", - " sweep_dir: str, seed: int = 0\n", - ") -> Dict[str, Dict[str, np.ndarray]]:\n", - " \"\"\"\n", - " Load loss histories for all experiments in a sweep.\n", - "\n", - " Args:\n", - " sweep_dir: Path to sweep directory\n", - " seed: Seed number (default: 0)\n", - "\n", - " Returns:\n", - " Dictionary mapping experiment names to their loss dictionaries\n", - " \"\"\"\n", - " experiments = get_sweep_experiments(sweep_dir)\n", - " all_losses = {}\n", - "\n", - " for exp_name in experiments:\n", - " all_losses[exp_name] = load_experiment_losses(sweep_dir, exp_name, seed)\n", - "\n", - " return all_losses\n", - "\n", - "\n", - "def remove_outliers_local(loss_history, window=10, threshold=3.0):\n", - " \"\"\"\n", - " Replace outliers with local median if they deviate too much.\n", - "\n", - " Args:\n", - " loss_history: Array of loss values\n", - " window: Window size for local statistics\n", - " threshold: How many local standard deviations to consider an outlier\n", - "\n", - " Returns:\n", - " Tuple of (cleaned loss history, whether any outliers were found)\n", - " \"\"\"\n", - " cleaned = loss_history.copy()\n", - " half_window = window // 2\n", - " outliers_found = False\n", - 
"\n", - " for i in range(len(loss_history)):\n", - " start = max(0, i - half_window)\n", - " end = min(len(loss_history), i + half_window + 1)\n", - " local_window = loss_history[start:end]\n", - "\n", - " local_median = np.median(local_window)\n", - " local_std = np.std(local_window)\n", - "\n", - " # If the value is too far from local median, replace it\n", - " if abs(loss_history[i] - local_median) > threshold * local_std:\n", - " cleaned[i] = local_median\n", - " outliers_found = True\n", - "\n", - " return cleaned, outliers_found" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6b5653c8", - "metadata": {}, - "outputs": [], - "source": [ - "# Set up your sweep directory\n", - "sweep_dir = \"/home/facosta/group-agf/sweeps/seq_seq_sweep_20251113_120513\"\n", - "\n", - "# Get all experiments in the sweep\n", - "experiments = get_sweep_experiments(sweep_dir)\n", - "print(f\"Found experiments: {experiments}\")\n", - "\n", - "# Load losses for all experiments\n", - "all_losses = load_all_sweep_losses(sweep_dir)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fa6b246a", - "metadata": {}, - "outputs": [], - "source": [ - "def plot_loss_comparison(\n", - " sweep_dir: str,\n", - " experiments: Optional[List[str]] = None,\n", - " loss_type: str = \"train\",\n", - " log_scale: bool = True,\n", - " figsize: tuple = (10, 6),\n", - " seed: int = 0,\n", - " remove_outliers: bool = False,\n", - " outlier_window: int = 10,\n", - " outlier_threshold: float = 3.0,\n", - " template_2d: Optional[np.ndarray] = None,\n", - " p1: Optional[int] = None,\n", - " p2: Optional[int] = None,\n", - " show_theory_bands: bool = True,\n", - " num_theory_lines: Optional[int] = None,\n", - " color_by_k: bool = True,\n", - " cmap: str = \"viridis\",\n", - "):\n", - " \"\"\"\n", - " Plot and compare loss curves from multiple experiments.\n", - "\n", - " Args:\n", - " sweep_dir: Path to sweep directory\n", - " experiments: List of experiment names to plot (None = all experiments)\n", - " loss_type: 'train' or 'val'\n", - " log_scale: Whether to use log scale for both axes\n", - " figsize: Figure size tuple\n", - " seed: Seed number (default: 0)\n", - " remove_outliers: Whether to remove outliers using local outlier replacement\n", - " outlier_window: Window size for outlier detection (default: 10)\n", - " outlier_threshold: Threshold in standard deviations for outlier detection (default: 3.0)\n", - " template_2d: Optional 2D template array for computing theory lines\n", - " p1: First dimension of template (required if template_2d is provided)\n", - " p2: Second dimension of template (required if template_2d is provided)\n", - " show_theory_bands: Whether to show colored bands between theory lines (default: True)\n", - " num_theory_lines: Number of theory lines to show (default: None = show all)\n", - " color_by_k: Whether to color lines by k value (default: True)\n", - " cmap: Colormap name for k-based coloring (default: 'viridis')\n", - " \"\"\"\n", - " if experiments is None:\n", - " experiments = get_sweep_experiments(sweep_dir)\n", - "\n", - " fig, ax = plt.subplots(figsize=figsize)\n", - "\n", - " # Compute theory lines if template is provided\n", - " theory_levels = None\n", - " if template_2d is not None:\n", - " if p1 is None or p2 is None:\n", - " raise ValueError(\"p1 and p2 must be provided if template_2d is given\")\n", - "\n", - " # Import the helper function (assuming it's in utils.py)\n", - " from gagf.rnns.utils import get_power_2d_adele\n", - "\n", - " # 
Compute power spectrum of template\n", - " _, _, power = get_power_2d_adele(template_2d)\n", - " power_flat = np.sort(power.flatten()[power.flatten() > 1e-20])[::-1]\n", - "\n", - " # Theory levels (cumulative tail sums)\n", - " alpha_values = np.array(\n", - " [np.sum(power_flat[k:]) for k in range(len(power_flat))]\n", - " )\n", - " coef = 1.0 / (p1 * p2)\n", - " theory_levels = coef * alpha_values # strictly decreasing\n", - "\n", - " # Limit number of lines if specified\n", - " if num_theory_lines is not None:\n", - " theory_levels = theory_levels[: num_theory_lines + 1]\n", - "\n", - " # Generate colors for bands\n", - " n_bands = len(theory_levels) - 1\n", - " colors = plt.cm.tab10(np.linspace(0, 1, max(n_bands, 1)))\n", - "\n", - " # Draw colored bands between theory lines\n", - " if show_theory_bands and n_bands > 0:\n", - " for i in range(n_bands):\n", - " y_top = theory_levels[i]\n", - " y_bot = theory_levels[i + 1]\n", - " ax.axhspan(\n", - " y_bot,\n", - " y_top,\n", - " facecolor=colors[i % len(colors)],\n", - " alpha=0.15,\n", - " zorder=-3,\n", - " )\n", - "\n", - " # Draw the black theory lines\n", - " for y in theory_levels:\n", - " ax.axhline(\n", - " y=y,\n", - " color=\"black\",\n", - " linestyle=\"--\",\n", - " linewidth=1.5,\n", - " alpha=0.7,\n", - " zorder=-2,\n", - " label=\"_nolegend_\",\n", - " )\n", - "\n", - " # Extract k values and set up colormap\n", - " k_values = {}\n", - " if color_by_k:\n", - " import re\n", - "\n", - " for exp_name in experiments:\n", - " # Try to extract k value from experiment name (e.g., \"k_2_seqseq\" -> 2)\n", - " match = re.search(r\"k[_\\s]*(\\d+)\", exp_name, re.IGNORECASE)\n", - " if match:\n", - " k_values[exp_name] = int(match.group(1))\n", - "\n", - " if k_values:\n", - " k_min = min(k_values.values())\n", - " k_max = max(k_values.values())\n", - " norm = plt.cm.colors.Normalize(vmin=k_min, vmax=k_max)\n", - " colormap = plt.cm.get_cmap(cmap)\n", - " scalar_map = plt.cm.ScalarMappable(norm=norm, cmap=colormap)\n", - "\n", - " # Plot loss curves\n", - " for exp_name in experiments:\n", - " losses = load_experiment_losses(sweep_dir, exp_name, seed)\n", - " if loss_type in losses:\n", - " loss_history = losses[loss_type]\n", - "\n", - " # Apply outlier removal if requested\n", - " if remove_outliers:\n", - " loss_history, outliers_found = remove_outliers_local(\n", - " loss_history, window=outlier_window, threshold=outlier_threshold\n", - " )\n", - " if outliers_found:\n", - " print(f\"Outliers from {exp_name} removed for plot\")\n", - "\n", - " # Determine color\n", - " if color_by_k and exp_name in k_values:\n", - " color = scalar_map.to_rgba(k_values[exp_name])\n", - " else:\n", - " color = None # Use default color cycle\n", - "\n", - " ax.plot(loss_history, label=exp_name, alpha=0.8, linewidth=2, color=color)\n", - "\n", - " ax.set_xlabel(\"Step\", fontsize=14)\n", - " ax.set_ylabel(f\"{loss_type.capitalize()} Loss\", fontsize=14)\n", - " title = f\"{loss_type.capitalize()} Loss Comparison - {Path(sweep_dir).name}\"\n", - " if remove_outliers:\n", - " title += \" (outliers removed)\"\n", - " ax.set_title(title, fontsize=14)\n", - " ax.legend(fontsize=10)\n", - " ax.grid(True, alpha=0.3)\n", - "\n", - " if log_scale:\n", - " ax.set_xscale(\"log\")\n", - " ax.set_yscale(\"log\")\n", - "\n", - " # Add colorbar if coloring by k\n", - " if color_by_k and k_values:\n", - " cbar = plt.colorbar(scalar_map, ax=ax, label=\"k (sequence length)\", pad=0.02)\n", - " cbar.ax.tick_params(labelsize=10)\n", - "\n", - " plt.tight_layout()\n", 
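- "\n", - "    # The dashed black lines mark theory_levels[i]: the loss plateau predicted\n", - "    # once the i largest-power template frequencies have been learned, i.e.\n", - "    # (1/(p1*p2)) times the sum of the still-unlearned powers (the cumulative\n", - "    # tail sums computed above).\n",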
- " plt.show()\n", - "\n", - " return fig, ax, theory_levels" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a8e3f5e3", - "metadata": {}, - "outputs": [], - "source": [ - "template_path = os.path.join(sweep_dir, \"k_2_seqseq\", \"seed_0\", \"template.npy\")\n", - "template_2d = np.load(template_path)\n", - "p1, p2 = template_2d.shape\n", - "\n", - "plt.imshow(template_2d)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dd4bf1a8", - "metadata": {}, - "outputs": [], - "source": [ - "# First, load the template from one of your experiments\n", - "# (assuming they all use the same template)\n", - "template_path = os.path.join(sweep_dir, \"k_2_seqseq\", \"seed_0\", \"template.npy\")\n", - "template_2d = np.load(template_path)\n", - "p1, p2 = template_2d.shape\n", - "\n", - "# Plot with theory lines\n", - "fig, ax, theory_levels = plot_loss_comparison(\n", - " sweep_dir,\n", - " template_2d=template_2d,\n", - " p1=p1,\n", - " p2=p2,\n", - " remove_outliers=True,\n", - " num_theory_lines=10, # Show first 10 theory lines\n", - " log_scale=True,\n", - " cmap=\"viridis\",\n", - ")\n", - "\n", - "# Print the theory levels (plateau values)\n", - "print(\"Theory plateau levels:\")\n", - "for i, level in enumerate(theory_levels):\n", - " print(f\" Plateau {i}: {level:.6e}\")" - ] - }, - { - "cell_type": "markdown", - "id": "ead6933c", - "metadata": {}, - "source": [ - "### Time to reach plateau" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4e1d02b4", - "metadata": {}, - "outputs": [], - "source": [ - "def calculate_time_to_plateau(\n", - " sweep_dir: str,\n", - " template_2d: np.ndarray,\n", - " p1: int,\n", - " p2: int,\n", - " target_plateau_idx: int = 1,\n", - " loss_type: str = \"train\",\n", - " seed: int = 0,\n", - " tolerance: float = 1.1,\n", - " experiments: Optional[List[str]] = None,\n", - ") -> Dict[str, int]:\n", - " \"\"\"\n", - " Calculate the step at which each experiment's loss reaches a target plateau.\n", - "\n", - " Args:\n", - " sweep_dir: Path to sweep directory\n", - " template_2d: 2D template array for computing theory lines\n", - " p1, p2: Dimensions of template\n", - " target_plateau_idx: Index of target plateau (1 = first drop, 2 = second drop, etc.)\n", - " loss_type: 'train' or 'val'\n", - " seed: Seed number\n", - " tolerance: Multiplier for plateau threshold (loss must be <= tolerance * plateau_level)\n", - " experiments: List of experiment names (None = all experiments)\n", - "\n", - " Returns:\n", - " Dictionary mapping experiment names to step numbers\n", - " \"\"\"\n", - " from gagf.rnns.utils import get_power_2d_adele\n", - "\n", - " # Compute theory levels\n", - " _, _, power = get_power_2d_adele(template_2d)\n", - " power_flat = np.sort(power.flatten()[power.flatten() > 1e-20])[::-1]\n", - " alpha_values = np.array([np.sum(power_flat[k:]) for k in range(len(power_flat))])\n", - " coef = 1.0 / (p1 * p2)\n", - " theory_levels = coef * alpha_values\n", - "\n", - " if target_plateau_idx >= len(theory_levels):\n", - " raise ValueError(\n", - " f\"target_plateau_idx {target_plateau_idx} exceeds available plateaus ({len(theory_levels)})\"\n", - " )\n", - "\n", - " target_level = theory_levels[target_plateau_idx]\n", - " threshold = tolerance * target_level\n", - "\n", - " if experiments is None:\n", - " experiments = get_sweep_experiments(sweep_dir)\n", - "\n", - " times_to_plateau = {}\n", - "\n", - " for exp_name in experiments:\n", - " losses = load_experiment_losses(sweep_dir, exp_name, 
seed)\n", - " if loss_type in losses:\n", - " loss_history = losses[loss_type]\n", - "\n", - " # Find first step where loss drops below threshold\n", - " crossing_indices = np.where(loss_history <= threshold)[0]\n", - "\n", - " if len(crossing_indices) > 0:\n", - " times_to_plateau[exp_name] = crossing_indices[0]\n", - " else:\n", - " times_to_plateau[exp_name] = None # Never reached\n", - "\n", - " return times_to_plateau, theory_levels\n", - "\n", - "\n", - "def plot_time_to_plateau(\n", - " sweep_dir: str,\n", - " template_2d: np.ndarray,\n", - " p1: int,\n", - " p2: int,\n", - " target_plateau_idx: int = 1,\n", - " loss_type: str = \"train\",\n", - " seed: int = 0,\n", - " tolerance: float = 1.1,\n", - " experiments: Optional[List[str]] = None,\n", - " figsize: tuple = (10, 6),\n", - " sort_by: str = \"time\",\n", - "):\n", - " \"\"\"\n", - " Plot the time to reach a target plateau for different experiments.\n", - "\n", - " Args:\n", - " sweep_dir: Path to sweep directory\n", - " template_2d: 2D template array\n", - " p1, p2: Template dimensions\n", - " target_plateau_idx: Which plateau to measure (1 = first drop)\n", - " loss_type: 'train' or 'val'\n", - " seed: Seed number\n", - " tolerance: Multiplier for plateau threshold\n", - " experiments: List of experiment names (None = all)\n", - " figsize: Figure size\n", - " sort_by: 'time' (sort by time to plateau) or 'name' (alphabetical)\n", - " \"\"\"\n", - " times, theory_levels = calculate_time_to_plateau(\n", - " sweep_dir,\n", - " template_2d,\n", - " p1,\n", - " p2,\n", - " target_plateau_idx,\n", - " loss_type,\n", - " seed,\n", - " tolerance,\n", - " experiments,\n", - " )\n", - "\n", - " # Filter out experiments that never reached the plateau\n", - " reached = {k: v for k, v in times.items() if v is not None}\n", - " not_reached = [k for k, v in times.items() if v is None]\n", - "\n", - " if not reached:\n", - " print(\"No experiments reached the target plateau!\")\n", - " return\n", - "\n", - " # Sort experiments\n", - " if sort_by == \"time\":\n", - " sorted_items = sorted(reached.items(), key=lambda x: x[1])\n", - " else: # alphabetical\n", - " sorted_items = sorted(reached.items(), key=lambda x: x[0])\n", - "\n", - " exp_names = [item[0] for item in sorted_items]\n", - " steps = [item[1] for item in sorted_items]\n", - "\n", - " # Create bar plot\n", - " fig, ax = plt.subplots(figsize=figsize)\n", - " bars = ax.bar(range(len(exp_names)), steps, alpha=0.7, edgecolor=\"black\")\n", - "\n", - " # Color bars by value (gradient)\n", - " colors = plt.cm.viridis(np.linspace(0.2, 0.9, len(steps)))\n", - " for bar, color in zip(bars, colors):\n", - " bar.set_color(color)\n", - "\n", - " ax.set_xticks(range(len(exp_names)))\n", - " ax.set_xticklabels(exp_names, rotation=45, ha=\"right\")\n", - " ax.set_xlabel(\"Experiment\", fontsize=12)\n", - " ax.set_ylabel(\"Steps to Reach Plateau\", fontsize=12)\n", - " ax.set_title(\n", - " f\"Time to Reach Plateau {target_plateau_idx} (Level: {theory_levels[target_plateau_idx]:.2e})\",\n", - " fontsize=13,\n", - " )\n", - " ax.grid(True, alpha=0.3, axis=\"y\")\n", - "\n", - " # Add value labels on bars\n", - " for i, (bar, step) in enumerate(zip(bars, steps)):\n", - " height = bar.get_height()\n", - " ax.text(\n", - " bar.get_x() + bar.get_width() / 2.0,\n", - " height,\n", - " f\"{int(step):,}\",\n", - " ha=\"center\",\n", - " va=\"bottom\",\n", - " fontsize=9,\n", - " )\n", - "\n", - " plt.tight_layout()\n", - "\n", - " # Print summary\n", - " print(\n", - " f\"\\nTime to reach plateau 
{target_plateau_idx} (threshold: {theory_levels[target_plateau_idx]:.2e}):\"\n", - " )\n", - " print(\"-\" * 60)\n", - " for name, step in sorted_items:\n", - " print(f\" {name:30s}: {step:8,} steps\")\n", - "\n", - " if not_reached:\n", - " print(f\"\\nExperiments that did not reach plateau {target_plateau_idx}:\")\n", - " for name in not_reached:\n", - " print(f\" - {name}\")\n", - "\n", - " plt.show()\n", - "\n", - " return fig, ax, times, theory_levels" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b5b266d4", - "metadata": {}, - "outputs": [], - "source": [ - "def plot_time_to_plateau(\n", - " sweep_dir: str,\n", - " template_2d: np.ndarray,\n", - " p1: int,\n", - " p2: int,\n", - " target_plateau_idx: int = 1,\n", - " loss_type: str = \"train\",\n", - " seed: int = 0,\n", - " tolerance: float = 1.1,\n", - " experiments: Optional[List[str]] = None,\n", - " figsize: tuple = (10, 6),\n", - " sort_by: str = \"time\",\n", - " color_by_k: bool = True,\n", - " cmap: str = \"viridis\",\n", - "):\n", - " \"\"\"\n", - " Plot the time to reach a target plateau for different experiments.\n", - "\n", - " Args:\n", - " sweep_dir: Path to sweep directory\n", - " template_2d: 2D template array\n", - " p1, p2: Template dimensions\n", - " target_plateau_idx: Which plateau to measure (1 = first drop)\n", - " loss_type: 'train' or 'val'\n", - " seed: Seed number\n", - " tolerance: Multiplier for plateau threshold\n", - " experiments: List of experiment names (None = all)\n", - " figsize: Figure size\n", - " sort_by: 'time' (sort by time to plateau) or 'name' (alphabetical)\n", - " color_by_k: Whether to color bars by k value (default: True)\n", - " cmap: Colormap name for k-based coloring (default: 'viridis')\n", - " \"\"\"\n", - " times, theory_levels = calculate_time_to_plateau(\n", - " sweep_dir,\n", - " template_2d,\n", - " p1,\n", - " p2,\n", - " target_plateau_idx,\n", - " loss_type,\n", - " seed,\n", - " tolerance,\n", - " experiments,\n", - " )\n", - "\n", - " # Filter out experiments that never reached the plateau\n", - " reached = {k: v for k, v in times.items() if v is not None}\n", - " not_reached = [k for k, v in times.items() if v is None]\n", - "\n", - " if not reached:\n", - " print(\"No experiments reached the target plateau!\")\n", - " return\n", - "\n", - " # Sort experiments\n", - " if sort_by == \"time\":\n", - " sorted_items = sorted(reached.items(), key=lambda x: x[1])\n", - " else: # alphabetical\n", - " sorted_items = sorted(reached.items(), key=lambda x: x[0])\n", - "\n", - " exp_names = [item[0] for item in sorted_items]\n", - " steps = [item[1] for item in sorted_items]\n", - "\n", - " # Extract k values and set up colormap\n", - " k_values = []\n", - " if color_by_k:\n", - " import re\n", - "\n", - " for exp_name in exp_names:\n", - " # Try to extract k value from experiment name (e.g., \"k_2_seqseq\" -> 2)\n", - " match = re.search(r\"k[_\\s]*(\\d+)\", exp_name, re.IGNORECASE)\n", - " if match:\n", - " k_values.append(int(match.group(1)))\n", - " else:\n", - " k_values.append(None)\n", - "\n", - " # Check if we have valid k values\n", - " valid_k_values = [k for k in k_values if k is not None]\n", - " if valid_k_values:\n", - " k_min = min(valid_k_values)\n", - " k_max = max(valid_k_values)\n", - " norm = plt.cm.colors.Normalize(vmin=k_min, vmax=k_max)\n", - " colormap = plt.cm.get_cmap(cmap)\n", - " scalar_map = plt.cm.ScalarMappable(norm=norm, cmap=colormap)\n", - " else:\n", - " color_by_k = False # Fall back if no k values found\n", - "\n", - " 
# Create bar plot\n", - " fig, ax = plt.subplots(figsize=figsize)\n", - " bars = ax.bar(range(len(exp_names)), steps, alpha=0.7, edgecolor=\"black\")\n", - "\n", - " # Color bars\n", - " if color_by_k and valid_k_values:\n", - " for bar, k_val in zip(bars, k_values):\n", - " if k_val is not None:\n", - " bar.set_color(scalar_map.to_rgba(k_val))\n", - " else:\n", - " bar.set_color(\"gray\") # Fallback color\n", - " else:\n", - " # Color bars by position (gradient from blue to yellow)\n", - " colors = plt.cm.viridis(np.linspace(0.2, 0.9, len(steps)))\n", - " for bar, color in zip(bars, colors):\n", - " bar.set_color(color)\n", - "\n", - " ax.set_xticks(range(len(exp_names)))\n", - " ax.set_xticklabels(exp_names, rotation=45, ha=\"right\")\n", - " ax.set_xlabel(\"Experiment\", fontsize=12)\n", - " ax.set_ylabel(\"Steps to Reach Plateau\", fontsize=12)\n", - " ax.set_title(\n", - " f\"Time to Reach Plateau {target_plateau_idx} (Level: {theory_levels[target_plateau_idx]:.2e})\",\n", - " fontsize=13,\n", - " )\n", - " ax.grid(True, alpha=0.3, axis=\"y\")\n", - "\n", - " # Add value labels on bars\n", - " for i, (bar, step) in enumerate(zip(bars, steps)):\n", - " height = bar.get_height()\n", - " ax.text(\n", - " bar.get_x() + bar.get_width() / 2.0,\n", - " height,\n", - " f\"{int(step):,}\",\n", - " ha=\"center\",\n", - " va=\"bottom\",\n", - " fontsize=9,\n", - " )\n", - "\n", - " # Add colorbar if coloring by k\n", - " if color_by_k and valid_k_values:\n", - " cbar = plt.colorbar(scalar_map, ax=ax, label=\"k (sequence length)\", pad=0.02)\n", - " cbar.ax.tick_params(labelsize=10)\n", - "\n", - " plt.tight_layout()\n", - "\n", - " # Print summary\n", - " print(\n", - " f\"\\nTime to reach plateau {target_plateau_idx} (threshold: {theory_levels[target_plateau_idx]:.2e}):\"\n", - " )\n", - " print(\"-\" * 60)\n", - " for name, step in sorted_items:\n", - " print(f\" {name:30s}: {step:8,} steps\")\n", - "\n", - " if not_reached:\n", - " print(f\"\\nExperiments that did not reach plateau {target_plateau_idx}:\")\n", - " for name in not_reached:\n", - " print(f\" - {name}\")\n", - "\n", - " plt.show()\n", - "\n", - " return fig, ax, times, theory_levels" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b015323f", - "metadata": {}, - "outputs": [], - "source": [ - "# Load template\n", - "template_path = os.path.join(sweep_dir, \"k_2_seqseq\", \"seed_0\", \"template.npy\")\n", - "template_2d = np.load(template_path)\n", - "p1, p2 = template_2d.shape\n", - "\n", - "# get the times without plotting\n", - "times_dict, theory_levels = calculate_time_to_plateau(\n", - " sweep_dir, template_2d, p1, p2, target_plateau_idx=1\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "eb63ea4a", - "metadata": {}, - "outputs": [], - "source": [ - "# Plot time to first drop (plateau index 1)\n", - "fig, ax, times, levels = plot_time_to_plateau(\n", - " sweep_dir,\n", - " template_2d,\n", - " p1,\n", - " p2,\n", - " target_plateau_idx=1, # First drop\n", - " tolerance=1.1, # Loss must be <= 1.1 * theory_level\n", - " sort_by=\"name\",\n", - ")\n", - "\n", - "# Plot time to second drop (plateau index 2)\n", - "plot_time_to_plateau(\n", - " sweep_dir,\n", - " template_2d,\n", - " p1,\n", - " p2,\n", - " target_plateau_idx=2,\n", - " tolerance=1.05, # Second drop\n", - " cmap=\"viridis\",\n", - " sort_by=\"name\",\n", - ")\n", - "\n", - "# Plot time to third drop (plateau index 3)\n", - "plot_time_to_plateau(\n", - " sweep_dir,\n", - " template_2d,\n", - " p1,\n", - " 
p2,\n", - "    target_plateau_idx=3,\n", - "    tolerance=1.05, # Third drop\n", - "    cmap=\"viridis\",\n", - "    sort_by=\"name\",\n", - ")\n", - "\n", - "# Plot time to fourth drop (plateau index 4)\n", - "plot_time_to_plateau(\n", - "    sweep_dir,\n", - "    template_2d,\n", - "    p1,\n", - "    p2,\n", - "    target_plateau_idx=4,\n", - "    tolerance=1.05, # Fourth drop\n", - "    cmap=\"viridis\",\n", - "    sort_by=\"name\",\n", - ");" - ] - }, - { - "cell_type": "markdown", - "id": "5536cdb2", - "metadata": {}, - "source": [ - "## Sequence-to-one sweep across different values of k (sequence length)" - ] - }, - { - "cell_type": "markdown", - "id": "063d9df4", - "metadata": {}, - "source": [ - "- 1st sweep dir: \"/home/facosta/group-agf/sweeps/optim_sweep_20251113_001549\" (k=4, 5)\n", - "- 2nd sweep dir: \"/home/facosta/group-agf/sweeps/optim_sweep_20251113_145528\" (k=2, 3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d96f443e", - "metadata": {}, - "outputs": [], - "source": [ - "def load_losses_from_multiple_sweeps(\n", - "    sweep_experiments: Dict[str, List[str]], loss_type: str = \"train\", seed: int = 0\n", - ") -> Dict[str, Dict[str, np.ndarray]]:\n", - "    \"\"\"\n", - "    Load loss histories from experiments across multiple sweeps.\n", - "\n", - "    Args:\n", - "        sweep_experiments: Dictionary mapping sweep_dir -> list of experiment names\n", - "            e.g., {sweep_dir_1: [\"exp1\", \"exp2\"], sweep_dir_2: [\"exp3\", \"exp4\"]}\n", - "        loss_type: 'train' or 'val'\n", - "        seed: Seed number\n", - "\n", - "    Returns:\n", - "        Dictionary mapping experiment names to their loss dictionaries\n", - "        Note: Experiment names must be unique across sweeps\n", - "    \"\"\"\n", - "    all_losses = {}\n", - "\n", - "    for sweep_dir, exp_names in sweep_experiments.items():\n", - "        for exp_name in exp_names:\n", - "            losses = load_experiment_losses(sweep_dir, exp_name, seed)\n", - "            if loss_type in losses:\n", - "                # Store with experiment name as key\n", - "                all_losses[exp_name] = losses\n", - "            else:\n", - "                print(\n", - "                    f\"Warning: No {loss_type} loss found for {exp_name} in {sweep_dir}\"\n", - "                )\n", - "\n", - "    return all_losses\n", - "\n", - "\n", - "def create_color_mapping_from_multiple_sweeps(\n", - "    sweep_experiments: Dict[str, List[str]],\n", - "    parameter_path: str,\n", - "    cmap: str = \"viridis\",\n", - "    seed: int = 0,\n", - "    log_scale: bool = False,\n", - ") -> tuple:\n", - "    \"\"\"\n", - "    Create color mapping for experiments across multiple sweeps.\n", - "\n", - "    Args:\n", - "        sweep_experiments: Dictionary mapping sweep_dir -> list of experiment names\n", - "        parameter_path: Dot-separated path to parameter (e.g., 'data.batch_size')\n", - "        cmap: Colormap name\n", - "        seed: Seed number\n", - "        log_scale: Whether to use logarithmic scale for color mapping\n", - "\n", - "    Returns:\n", - "        Tuple of (color_mapping dict, scalar_map, param_values dict)\n", - "    \"\"\"\n", - "    param_values = {}\n", - "\n", - "    # Extract parameters from all experiments across all sweeps\n", - "    for sweep_dir, exp_names in sweep_experiments.items():\n", - "        for exp_name in exp_names:\n", - "            value = extract_config_parameter(sweep_dir, exp_name, parameter_path, seed)\n", - "            if value is not None:\n", - "                param_values[exp_name] = value\n", - "\n", - "    if not param_values:\n", - "        print(f\"Warning: Could not extract '{parameter_path}' from any experiments\")\n", - "        return {}, None, {}\n", - "\n", - "    # Create color mapping\n", - "    values = list(param_values.values())\n", - "    v_min = min(values)\n", - "    v_max = 
max(values)\n", - "\n", - " # Use log or linear normalization\n", - " if log_scale:\n", - " if v_min <= 0:\n", - " print(\n", - " f\"Warning: log_scale requested but found non-positive values (min={v_min}). Using linear scale.\"\n", - " )\n", - " norm = plt.cm.colors.Normalize(vmin=v_min, vmax=v_max)\n", - " else:\n", - " norm = plt.cm.colors.LogNorm(vmin=v_min, vmax=v_max)\n", - " else:\n", - " norm = plt.cm.colors.Normalize(vmin=v_min, vmax=v_max)\n", - "\n", - " colormap = plt.cm.get_cmap(cmap)\n", - "\n", - " color_mapping = {}\n", - " for exp_name, value in param_values.items():\n", - " color_mapping[exp_name] = colormap(norm(value))\n", - "\n", - " scalar_map = plt.cm.ScalarMappable(norm=norm, cmap=colormap)\n", - "\n", - " return color_mapping, scalar_map, param_values\n", - "\n", - "\n", - "def plot_loss_comparison_multi_sweep(\n", - " sweep_experiments: Dict[str, List[str]],\n", - " loss_type: str = \"train\",\n", - " log_scale: bool = True,\n", - " figsize: tuple = (10, 6),\n", - " seed: int = 0,\n", - " remove_outliers: bool = False,\n", - " outlier_window: int = 10,\n", - " outlier_threshold: float = 3.0,\n", - " template_2d: Optional[np.ndarray] = None,\n", - " p1: Optional[int] = None,\n", - " p2: Optional[int] = None,\n", - " show_theory_bands: bool = True,\n", - " num_theory_lines: Optional[int] = None,\n", - " color_mapping: Optional[Dict[str, tuple]] = None,\n", - " colorbar_label: Optional[str] = None,\n", - " scalar_map: Optional[plt.cm.ScalarMappable] = None,\n", - "):\n", - " \"\"\"\n", - " Plot and compare loss curves from experiments across multiple sweeps.\n", - "\n", - " Args:\n", - " sweep_experiments: Dictionary mapping sweep_dir -> list of experiment names\n", - " e.g., {sweep_dir_1: [\"exp1\", \"exp2\"], sweep_dir_2: [\"exp3\"]}\n", - " loss_type: 'train' or 'val'\n", - " log_scale: Whether to use log scale for both axes\n", - " figsize: Figure size tuple\n", - " seed: Seed number\n", - " remove_outliers: Whether to remove outliers\n", - " outlier_window: Window size for outlier detection\n", - " outlier_threshold: Threshold for outlier detection\n", - " template_2d: Optional 2D template array for computing theory lines\n", - " p1, p2: Template dimensions\n", - " show_theory_bands: Whether to show colored bands between theory lines\n", - " num_theory_lines: Number of theory lines to show\n", - " color_mapping: Dictionary mapping experiment names to RGBA colors\n", - " colorbar_label: Label for colorbar\n", - " scalar_map: ScalarMappable for colorbar\n", - " \"\"\"\n", - " fig, ax = plt.subplots(figsize=figsize)\n", - "\n", - " # Compute theory lines if template is provided\n", - " theory_levels = None\n", - " if template_2d is not None:\n", - " if p1 is None or p2 is None:\n", - " raise ValueError(\"p1 and p2 must be provided if template_2d is given\")\n", - "\n", - " from gagf.rnns.utils import get_power_2d_adele\n", - "\n", - " _, _, power = get_power_2d_adele(template_2d)\n", - " power_flat = np.sort(power.flatten()[power.flatten() > 1e-20])[::-1]\n", - "\n", - " alpha_values = np.array(\n", - " [np.sum(power_flat[k:]) for k in range(len(power_flat))]\n", - " )\n", - " coef = 1.0 / (p1 * p2)\n", - " theory_levels = coef * alpha_values\n", - "\n", - " if num_theory_lines is not None:\n", - " theory_levels = theory_levels[: num_theory_lines + 1]\n", - "\n", - " n_bands = len(theory_levels) - 1\n", - " colors = plt.cm.tab10(np.linspace(0, 1, max(n_bands, 1)))\n", - "\n", - " if show_theory_bands and n_bands > 0:\n", - " for i in range(n_bands):\n", - " y_top = 
theory_levels[i]\n", - " y_bot = theory_levels[i + 1]\n", - " ax.axhspan(\n", - " y_bot,\n", - " y_top,\n", - " facecolor=colors[i % len(colors)],\n", - " alpha=0.15,\n", - " zorder=-3,\n", - " )\n", - "\n", - " for y in theory_levels:\n", - " ax.axhline(\n", - " y=y,\n", - " color=\"black\",\n", - " linestyle=\"--\",\n", - " linewidth=1.5,\n", - " alpha=0.7,\n", - " zorder=-2,\n", - " label=\"_nolegend_\",\n", - " )\n", - "\n", - " # Plot loss curves from all sweeps\n", - " for sweep_dir, exp_names in sweep_experiments.items():\n", - " for exp_name in exp_names:\n", - " losses = load_experiment_losses(sweep_dir, exp_name, seed)\n", - " if loss_type in losses:\n", - " loss_history = losses[loss_type]\n", - "\n", - " if remove_outliers:\n", - " loss_history, outliers_found = remove_outliers_local(\n", - " loss_history, window=outlier_window, threshold=outlier_threshold\n", - " )\n", - " if outliers_found:\n", - " print(f\"Outliers from {exp_name} removed for plot\")\n", - "\n", - " # Determine color\n", - " if color_mapping and exp_name in color_mapping:\n", - " color = color_mapping[exp_name]\n", - " else:\n", - " color = None\n", - "\n", - " ax.plot(\n", - " loss_history, label=exp_name, alpha=0.8, linewidth=2, color=color\n", - " )\n", - "\n", - " ax.set_xlabel(\"Step\", fontsize=14)\n", - " ax.set_ylabel(f\"{loss_type.capitalize()} Loss\", fontsize=14)\n", - " ax.set_title(\n", - " f\"{loss_type.capitalize()} Loss Comparison (Multiple Sweeps)\", fontsize=14\n", - " )\n", - " ax.legend(fontsize=10)\n", - " ax.grid(True, alpha=0.3)\n", - "\n", - " if log_scale:\n", - " ax.set_xscale(\"log\")\n", - " ax.set_yscale(\"log\")\n", - "\n", - " if color_mapping and scalar_map is not None:\n", - " label = colorbar_label if colorbar_label else \"Parameter Value\"\n", - " cbar = plt.colorbar(scalar_map, ax=ax, label=label, pad=0.02)\n", - " cbar.ax.tick_params(labelsize=10)\n", - "\n", - " plt.tight_layout()\n", - " plt.show()\n", - "\n", - " return fig, ax, theory_levels" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9ac8e706", - "metadata": {}, - "outputs": [], - "source": [ - "# Define which experiments from which sweeps\n", - "sweep_experiments = {\n", - " \"/home/facosta/group-agf/sweeps/optim_sweep_20251113_001549\": [\n", - " \"k_4_seqone\",\n", - " \"k_5_seqone\",\n", - " ],\n", - " \"/home/facosta/group-agf/sweeps/optim_sweep_20251113_145528\": [\n", - " \"k_2_seqone\",\n", - " \"k_3_seqone\",\n", - " ],\n", - "}\n", - "\n", - "# Load template (assuming same template across all)\n", - "template_path = os.path.join(\n", - " list(sweep_experiments.keys())[0], # First sweep\n", - " list(sweep_experiments.values())[0][0], # First experiment\n", - " \"seed_0\",\n", - " \"template.npy\",\n", - ")\n", - "template_2d = np.load(template_path)\n", - "p1, p2 = template_2d.shape\n", - "\n", - "# Create color mapping based on parameter across all sweeps\n", - "color_map, scalar_map, k_values = create_color_mapping_from_multiple_sweeps(\n", - " sweep_experiments,\n", - " \"data.k\", # or 'data.batch_size', 'training.learning_rate', etc.\n", - " cmap=\"viridis\",\n", - " log_scale=False,\n", - ")\n", - "\n", - "# Plot\n", - "plot_loss_comparison_multi_sweep(\n", - " sweep_experiments,\n", - " template_2d=template_2d,\n", - " p1=p1,\n", - " p2=p2,\n", - " color_mapping=color_map,\n", - " colorbar_label=\"k (sequence length)\",\n", - " scalar_map=scalar_map,\n", - " num_theory_lines=10,\n", - " remove_outliers=False,\n", - " log_scale=True,\n", - ")\n", - "\n", - "# Print 
extracted values to verify\n", - "print(\"Extracted k values:\", k_values)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1019938c", - "metadata": {}, - "outputs": [], - "source": [ - "def calculate_time_to_plateau_multi_sweep(\n", - " sweep_experiments: Dict[str, List[str]],\n", - " template_2d: np.ndarray,\n", - " p1: int,\n", - " p2: int,\n", - " target_plateau_idx: int = 1,\n", - " loss_type: str = \"train\",\n", - " seed: int = 0,\n", - " tolerance: float = 1.1,\n", - ") -> tuple:\n", - " \"\"\"\n", - " Calculate time to plateau for experiments across multiple sweeps.\n", - "\n", - " Args:\n", - " sweep_experiments: Dictionary mapping sweep_dir -> list of experiment names\n", - " template_2d: 2D template array\n", - " p1, p2: Template dimensions\n", - " target_plateau_idx: Which plateau to measure (1 = first drop)\n", - " loss_type: 'train' or 'val'\n", - " seed: Seed number\n", - " tolerance: Multiplier for plateau threshold\n", - "\n", - " Returns:\n", - " Tuple of (times_to_plateau dict, theory_levels array)\n", - " \"\"\"\n", - " from gagf.rnns.utils import get_power_2d_adele\n", - "\n", - " # Compute theory levels\n", - " _, _, power = get_power_2d_adele(template_2d)\n", - " power_flat = np.sort(power.flatten()[power.flatten() > 1e-20])[::-1]\n", - " alpha_values = np.array([np.sum(power_flat[k:]) for k in range(len(power_flat))])\n", - " coef = 1.0 / (p1 * p2)\n", - " theory_levels = coef * alpha_values\n", - "\n", - " if target_plateau_idx >= len(theory_levels):\n", - " raise ValueError(\n", - " f\"target_plateau_idx {target_plateau_idx} exceeds available plateaus ({len(theory_levels)})\"\n", - " )\n", - "\n", - " target_level = theory_levels[target_plateau_idx]\n", - " threshold = tolerance * target_level\n", - "\n", - " times_to_plateau = {}\n", - "\n", - " # Process all experiments across all sweeps\n", - " for sweep_dir, exp_names in sweep_experiments.items():\n", - " for exp_name in exp_names:\n", - " losses = load_experiment_losses(sweep_dir, exp_name, seed)\n", - " if loss_type in losses:\n", - " loss_history = losses[loss_type]\n", - "\n", - " # Find first step where loss drops below threshold\n", - " crossing_indices = np.where(loss_history <= threshold)[0]\n", - "\n", - " if len(crossing_indices) > 0:\n", - " times_to_plateau[exp_name] = crossing_indices[0]\n", - " else:\n", - " times_to_plateau[exp_name] = None # Never reached\n", - "\n", - " return times_to_plateau, theory_levels\n", - "\n", - "\n", - "def plot_time_to_plateau_multi_sweep(\n", - " sweep_experiments: Dict[str, List[str]],\n", - " template_2d: np.ndarray,\n", - " p1: int,\n", - " p2: int,\n", - " target_plateau_idx: int = 1,\n", - " loss_type: str = \"train\",\n", - " seed: int = 0,\n", - " tolerance: float = 1.1,\n", - " figsize: tuple = (10, 6),\n", - " sort_by: str = \"time\",\n", - " color_mapping: Optional[Dict[str, tuple]] = None,\n", - " colorbar_label: Optional[str] = None,\n", - " scalar_map: Optional[plt.cm.ScalarMappable] = None,\n", - " show_not_reached: bool = True,\n", - "):\n", - " \"\"\"\n", - " Plot time to reach a target plateau for experiments across multiple sweeps.\n", - "\n", - " Args:\n", - " sweep_experiments: Dictionary mapping sweep_dir -> list of experiment names\n", - " template_2d: 2D template array\n", - " p1, p2: Template dimensions\n", - " target_plateau_idx: Which plateau to measure (1 = first drop)\n", - " loss_type: 'train' or 'val'\n", - " seed: Seed number\n", - " tolerance: Multiplier for plateau threshold\n", - " figsize: Figure size\n", 
- " sort_by: 'time' (sort by time to plateau) or 'name' (alphabetical)\n", - " color_mapping: Dictionary mapping experiment names to RGBA colors\n", - " colorbar_label: Label for colorbar\n", - " scalar_map: ScalarMappable for colorbar\n", - " show_not_reached: Whether to show experiments that didn't reach plateau\n", - " \"\"\"\n", - " times, theory_levels = calculate_time_to_plateau_multi_sweep(\n", - " sweep_experiments,\n", - " template_2d,\n", - " p1,\n", - " p2,\n", - " target_plateau_idx,\n", - " loss_type,\n", - " seed,\n", - " tolerance,\n", - " )\n", - "\n", - " # Separate reached vs not reached\n", - " reached = {k: v for k, v in times.items() if v is not None}\n", - " not_reached = {k: v for k, v in times.items() if v is None}\n", - "\n", - " if not reached and not not_reached:\n", - " print(\"No experiments found!\")\n", - " return\n", - "\n", - " # Sort experiments that reached the plateau\n", - " if reached:\n", - " if sort_by == \"time\":\n", - " sorted_reached = sorted(reached.items(), key=lambda x: x[1])\n", - " else: # alphabetical\n", - " sorted_reached = sorted(reached.items(), key=lambda x: x[0])\n", - " else:\n", - " sorted_reached = []\n", - "\n", - " # Add not-reached experiments at the end if requested\n", - " if show_not_reached and not_reached:\n", - " not_reached_sorted = sorted(not_reached.keys())\n", - " exp_names = [item[0] for item in sorted_reached] + not_reached_sorted\n", - " # Use a very large value for visualization (e.g., max reached time * 1.5)\n", - " max_reached_time = max([v for v in reached.values()]) if reached else 10000\n", - " placeholder_time = max_reached_time * 1.3\n", - " steps = [item[1] for item in sorted_reached] + [placeholder_time] * len(\n", - " not_reached_sorted\n", - " )\n", - " reached_mask = [True] * len(sorted_reached) + [False] * len(not_reached_sorted)\n", - " else:\n", - " exp_names = [item[0] for item in sorted_reached]\n", - " steps = [item[1] for item in sorted_reached]\n", - " reached_mask = [True] * len(sorted_reached)\n", - "\n", - " if not exp_names:\n", - " print(\"No experiments to plot!\")\n", - " return\n", - "\n", - " # Create bar plot\n", - " fig, ax = plt.subplots(figsize=figsize)\n", - " bars = ax.bar(range(len(exp_names)), steps, alpha=0.7, edgecolor=\"black\")\n", - "\n", - " # Color bars\n", - " for i, (bar, exp_name, did_reach) in enumerate(zip(bars, exp_names, reached_mask)):\n", - " if not did_reach:\n", - " # Gray out experiments that didn't reach plateau\n", - " bar.set_color(\"lightgray\")\n", - " bar.set_alpha(0.4)\n", - " bar.set_hatch(\"///\")\n", - " elif color_mapping and exp_name in color_mapping:\n", - " bar.set_color(color_mapping[exp_name])\n", - " else:\n", - " # Default gradient coloring\n", - " colors = plt.cm.viridis(\n", - " np.linspace(0.2, 0.9, len([m for m in reached_mask if m]))\n", - " )\n", - " bar.set_color(colors[i] if did_reach else \"lightgray\")\n", - "\n", - " ax.set_xticks(range(len(exp_names)))\n", - " ax.set_xticklabels(exp_names, rotation=45, ha=\"right\")\n", - " ax.set_xlabel(\"Experiment\", fontsize=12)\n", - " ax.set_ylabel(\"Steps to Reach Plateau\", fontsize=12)\n", - " ax.set_title(\n", - " f\"Time to Reach Plateau {target_plateau_idx} (Level: {theory_levels[target_plateau_idx]:.2e})\",\n", - " fontsize=13,\n", - " )\n", - " ax.grid(True, alpha=0.3, axis=\"y\")\n", - "\n", - " # Add value labels on bars\n", - " for i, (bar, step, did_reach, exp_name) in enumerate(\n", - " zip(bars, steps, reached_mask, exp_names)\n", - " ):\n", - " height = bar.get_height()\n", - 
" if did_reach:\n", - " ax.text(\n", - " bar.get_x() + bar.get_width() / 2.0,\n", - " height,\n", - " f\"{int(step):,}\",\n", - " ha=\"center\",\n", - " va=\"bottom\",\n", - " fontsize=9,\n", - " )\n", - " else:\n", - " # Add \"Did not reach\" annotation\n", - " ax.text(\n", - " bar.get_x() + bar.get_width() / 2.0,\n", - " height / 2,\n", - " \"Did not\\nreach\",\n", - " ha=\"center\",\n", - " va=\"center\",\n", - " fontsize=8,\n", - " style=\"italic\",\n", - " color=\"darkgray\",\n", - " weight=\"bold\",\n", - " )\n", - "\n", - " # Add colorbar if color mapping provided (only for experiments that reached)\n", - " if color_mapping and scalar_map is not None and reached:\n", - " label = colorbar_label if colorbar_label else \"Parameter Value\"\n", - " cbar = plt.colorbar(scalar_map, ax=ax, label=label, pad=0.02)\n", - " cbar.ax.tick_params(labelsize=10)\n", - "\n", - " plt.tight_layout()\n", - "\n", - " # Print summary\n", - " print(\n", - " f\"\\nTime to reach plateau {target_plateau_idx} (threshold: {theory_levels[target_plateau_idx]:.2e}):\"\n", - " )\n", - " print(\"-\" * 60)\n", - " if reached:\n", - " for name, step in sorted_reached:\n", - " print(f\" {name:30s}: {step:8,} steps\")\n", - "\n", - " if not_reached:\n", - " print(f\"\\nExperiments that did not reach plateau {target_plateau_idx}:\")\n", - " for name in not_reached_sorted:\n", - " print(f\" - {name}\")\n", - "\n", - " plt.show()\n", - "\n", - " return fig, ax, times, theory_levels" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "84c284b1", - "metadata": {}, - "outputs": [], - "source": [ - "# Define your experiments across multiple sweeps\n", - "sweep_experiments = {\n", - " \"/home/facosta/group-agf/sweeps/optim_sweep_20251113_001549\": [\n", - " \"k_4_seqone\",\n", - " \"k_5_seqone\",\n", - " ],\n", - " \"/home/facosta/group-agf/sweeps/optim_sweep_20251113_145528\": [\n", - " \"k_2_seqone\",\n", - " \"k_3_seqone\",\n", - " ],\n", - "}\n", - "\n", - "# Load template\n", - "template_path = os.path.join(\n", - " \"/home/facosta/group-agf/sweeps/optim_sweep_20251113_001549\",\n", - " \"k_4_seqone\",\n", - " \"seed_0\",\n", - " \"template.npy\",\n", - ")\n", - "template_2d = np.load(template_path)\n", - "p1, p2 = template_2d.shape\n", - "\n", - "# Create color mapping by k value\n", - "color_map, scalar_map, k_values = create_color_mapping_from_multiple_sweeps(\n", - " sweep_experiments, \"data.k\", cmap=\"viridis\", log_scale=False\n", - ")\n", - "\n", - "# Plot time to first plateau\n", - "fig, ax, times, levels = plot_time_to_plateau_multi_sweep(\n", - " sweep_experiments,\n", - " template_2d,\n", - " p1,\n", - " p2,\n", - " target_plateau_idx=1, # First drop\n", - " tolerance=1.1,\n", - " color_mapping=color_map,\n", - " colorbar_label=\"k (sequence length)\",\n", - " scalar_map=scalar_map,\n", - " sort_by=\"name\", # or 'time'\n", - " show_not_reached=True, # Show experiments that didn't reach plateau\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "78bfc515", - "metadata": {}, - "source": [ - "## Batch size sweep (seq-to-one, k=3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ac01c062", - "metadata": {}, - "outputs": [], - "source": [ - "import yaml\n", - "\n", - "\n", - "def extract_config_parameter(\n", - " sweep_dir: str, experiment_name: str, parameter_path: str, seed: int = 0\n", - ") -> any:\n", - " \"\"\"\n", - " Extract a parameter value from an experiment's config file.\n", - "\n", - " Args:\n", - " sweep_dir: Path to sweep directory\n", - " 
experiment_name: Name of experiment\n", - " parameter_path: Dot-separated path to parameter (e.g., 'data.batch_size', 'training.learning_rate')\n", - " seed: Seed number (for getting from seed_X/config.yaml)\n", - "\n", - " Returns:\n", - " Parameter value or None if not found\n", - " \"\"\"\n", - " # Try configs directory first\n", - " config_path = Path(sweep_dir) / \"configs\" / f\"{experiment_name}_config.yaml\"\n", - "\n", - " # If not there, try seed_X directory\n", - " if not config_path.exists():\n", - " config_path = Path(sweep_dir) / experiment_name / f\"seed_{seed}\" / \"config.yaml\"\n", - "\n", - " if not config_path.exists():\n", - " return None\n", - "\n", - " try:\n", - " with open(config_path, \"r\") as f:\n", - " config = yaml.safe_load(f)\n", - "\n", - " # Navigate through nested structure using dot notation\n", - " value = config\n", - " for key in parameter_path.split(\".\"):\n", - " value = value[key]\n", - "\n", - " return value\n", - " except (KeyError, TypeError):\n", - " return None\n", - "\n", - "\n", - "def create_color_mapping(\n", - " sweep_dir: str,\n", - " parameter_path: str,\n", - " experiments: Optional[List[str]] = None,\n", - " cmap: str = \"viridis\",\n", - " seed: int = 0,\n", - " log_scale: bool = False,\n", - ") -> tuple:\n", - " \"\"\"\n", - " Create a color mapping for experiments based on a config parameter.\n", - "\n", - " Args:\n", - " sweep_dir: Path to sweep directory\n", - " parameter_path: Dot-separated path to parameter (e.g., 'data.batch_size')\n", - " experiments: List of experiment names (None = all)\n", - " cmap: Colormap name\n", - " seed: Seed number\n", - " log_scale: Whether to use logarithmic scale for color mapping\n", - "\n", - " Returns:\n", - " Tuple of (color_mapping dict, scalar_map, param_values dict)\n", - " \"\"\"\n", - " if experiments is None:\n", - " experiments = get_sweep_experiments(sweep_dir)\n", - "\n", - " # Extract parameter values for all experiments\n", - " param_values = {}\n", - " for exp_name in experiments:\n", - " value = extract_config_parameter(sweep_dir, exp_name, parameter_path, seed)\n", - " if value is not None:\n", - " param_values[exp_name] = value\n", - "\n", - " if not param_values:\n", - " print(f\"Warning: Could not extract '{parameter_path}' from any experiments\")\n", - " return {}, None, {}\n", - "\n", - " # Create color mapping\n", - " values = list(param_values.values())\n", - " v_min = min(values)\n", - " v_max = max(values)\n", - "\n", - " # Use log or linear normalization\n", - " if log_scale:\n", - " if v_min <= 0:\n", - " print(\n", - " f\"Warning: log_scale requested but found non-positive values (min={v_min}). 
Using linear scale.\"\n", - " )\n", - " norm = plt.cm.colors.Normalize(vmin=v_min, vmax=v_max)\n", - " else:\n", - " norm = plt.cm.colors.LogNorm(vmin=v_min, vmax=v_max)\n", - " else:\n", - " norm = plt.cm.colors.Normalize(vmin=v_min, vmax=v_max)\n", - "\n", - " colormap = plt.cm.get_cmap(cmap)\n", - "\n", - " color_mapping = {}\n", - " for exp_name, value in param_values.items():\n", - " color_mapping[exp_name] = colormap(norm(value))\n", - "\n", - " # Also return the scalar mappable for colorbar\n", - " scalar_map = plt.cm.ScalarMappable(norm=norm, cmap=colormap)\n", - "\n", - " return color_mapping, scalar_map, param_values" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b2da02fb", - "metadata": {}, - "outputs": [], - "source": [ - "def plot_loss_comparison(\n", - " sweep_dir: str,\n", - " experiments: Optional[List[str]] = None,\n", - " loss_type: str = \"train\",\n", - " log_scale: bool = True,\n", - " figsize: tuple = (10, 6),\n", - " seed: int = 0,\n", - " remove_outliers: bool = False,\n", - " outlier_window: int = 10,\n", - " outlier_threshold: float = 3.0,\n", - " template_2d: Optional[np.ndarray] = None,\n", - " p1: Optional[int] = None,\n", - " p2: Optional[int] = None,\n", - " show_theory_bands: bool = True,\n", - " num_theory_lines: Optional[int] = None,\n", - " color_mapping: Optional[Dict[str, tuple]] = None,\n", - " colorbar_label: Optional[str] = None,\n", - " scalar_map: Optional[plt.cm.ScalarMappable] = None,\n", - "):\n", - " \"\"\"\n", - " Plot and compare loss curves from multiple experiments.\n", - "\n", - " Args:\n", - " sweep_dir: Path to sweep directory\n", - " experiments: List of experiment names to plot (None = all experiments)\n", - " loss_type: 'train' or 'val'\n", - " log_scale: Whether to use log scale for both axes\n", - " figsize: Figure size tuple\n", - " seed: Seed number (default: 0)\n", - " remove_outliers: Whether to remove outliers using local outlier replacement\n", - " outlier_window: Window size for outlier detection (default: 10)\n", - " outlier_threshold: Threshold in standard deviations for outlier detection (default: 3.0)\n", - " template_2d: Optional 2D template array for computing theory lines\n", - " p1: First dimension of template (required if template_2d is provided)\n", - " p2: Second dimension of template (required if template_2d is provided)\n", - " show_theory_bands: Whether to show colored bands between theory lines (default: True)\n", - " num_theory_lines: Number of theory lines to show (default: None = show all)\n", - " color_mapping: Dictionary mapping experiment names to RGBA colors\n", - " colorbar_label: Label for colorbar (if color_mapping provided)\n", - " scalar_map: ScalarMappable for colorbar (if color_mapping provided)\n", - " \"\"\"\n", - " if experiments is None:\n", - " experiments = get_sweep_experiments(sweep_dir)\n", - "\n", - " fig, ax = plt.subplots(figsize=figsize)\n", - "\n", - " # Compute theory lines if template is provided\n", - " theory_levels = None\n", - " if template_2d is not None:\n", - " if p1 is None or p2 is None:\n", - " raise ValueError(\"p1 and p2 must be provided if template_2d is given\")\n", - "\n", - " # Import the helper function (assuming it's in utils.py)\n", - " from gagf.rnns.utils import get_power_2d_adele\n", - "\n", - " # Compute power spectrum of template\n", - " _, _, power = get_power_2d_adele(template_2d)\n", - " power_flat = np.sort(power.flatten()[power.flatten() > 1e-20])[::-1]\n", - "\n", - " # Theory levels (cumulative tail sums)\n", - " alpha_values 
= np.array(\n", - " [np.sum(power_flat[k:]) for k in range(len(power_flat))]\n", - " )\n", - " coef = 1.0 / (p1 * p2)\n", - " theory_levels = coef * alpha_values # strictly decreasing\n", - "\n", - " # Limit number of lines if specified\n", - " if num_theory_lines is not None:\n", - " theory_levels = theory_levels[: num_theory_lines + 1]\n", - "\n", - " # Generate colors for bands\n", - " n_bands = len(theory_levels) - 1\n", - " colors = plt.cm.tab10(np.linspace(0, 1, max(n_bands, 1)))\n", - "\n", - " # Draw colored bands between theory lines\n", - " if show_theory_bands and n_bands > 0:\n", - " for i in range(n_bands):\n", - " y_top = theory_levels[i]\n", - " y_bot = theory_levels[i + 1]\n", - " ax.axhspan(\n", - " y_bot,\n", - " y_top,\n", - " facecolor=colors[i % len(colors)],\n", - " alpha=0.15,\n", - " zorder=-3,\n", - " )\n", - "\n", - " # Draw the black theory lines\n", - " for y in theory_levels:\n", - " ax.axhline(\n", - " y=y,\n", - " color=\"black\",\n", - " linestyle=\"--\",\n", - " linewidth=1.5,\n", - " alpha=0.7,\n", - " zorder=-2,\n", - " label=\"_nolegend_\",\n", - " )\n", - "\n", - " # Plot loss curves\n", - " for exp_name in experiments:\n", - " losses = load_experiment_losses(sweep_dir, exp_name, seed)\n", - " if loss_type in losses:\n", - " loss_history = losses[loss_type]\n", - "\n", - " # Apply outlier removal if requested\n", - " if remove_outliers:\n", - " loss_history, outliers_found = remove_outliers_local(\n", - " loss_history, window=outlier_window, threshold=outlier_threshold\n", - " )\n", - " if outliers_found:\n", - " print(f\"Outliers from {exp_name} removed for plot\")\n", - "\n", - " # Determine color\n", - " if color_mapping and exp_name in color_mapping:\n", - " color = color_mapping[exp_name]\n", - " else:\n", - " color = None # Use default color cycle\n", - "\n", - " ax.plot(loss_history, label=exp_name, alpha=0.8, linewidth=2, color=color)\n", - "\n", - " ax.set_xlabel(\"Step\", fontsize=14)\n", - " ax.set_ylabel(f\"{loss_type.capitalize()} Loss\", fontsize=14)\n", - " title = f\"{loss_type.capitalize()} Loss Comparison - {Path(sweep_dir).name}\"\n", - " if remove_outliers:\n", - " title += \" (outliers removed)\"\n", - " ax.set_title(title, fontsize=14)\n", - " ax.legend(fontsize=10)\n", - " ax.grid(True, alpha=0.3)\n", - "\n", - " if log_scale:\n", - " ax.set_xscale(\"log\")\n", - " ax.set_yscale(\"log\")\n", - "\n", - " # Add colorbar if color mapping provided\n", - " if color_mapping and scalar_map is not None:\n", - " label = colorbar_label if colorbar_label else \"Parameter Value\"\n", - " cbar = plt.colorbar(scalar_map, ax=ax, label=label, pad=0.02)\n", - " cbar.ax.tick_params(labelsize=10)\n", - "\n", - " plt.tight_layout()\n", - " plt.show()\n", - "\n", - " return fig, ax, theory_levels" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8b2ce522", - "metadata": {}, - "outputs": [], - "source": [ - "# Set up your sweep directory\n", - "batch_sweep_dir = \"/home/facosta/group-agf/sweeps/batch_sweep_20251113_171834\"\n", - "\n", - "# Get all experiments in the sweep\n", - "batch_experiments = get_sweep_experiments(batch_sweep_dir)\n", - "print(f\"Found experiments: {batch_experiments}\")\n", - "\n", - "# Load losses for all experiments\n", - "batch_all_losses = load_all_sweep_losses(batch_sweep_dir)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "520500f3", - "metadata": {}, - "outputs": [], - "source": [ - "# First, load the template from one of your experiments\n", - "# (assuming they all use the 
same template)\n",
- "template_path = os.path.join(batch_sweep_dir, \"batch_1000\", \"seed_0\", \"template.npy\")\n",
- "template_2d = np.load(template_path)\n",
- "p1, p2 = template_2d.shape\n",
- "\n",
- "plt.imshow(template_2d)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b9761849",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Example 1: Color by batch_size\n",
- "color_map, scalar_map, batch_values = create_color_mapping(\n",
- "    batch_sweep_dir,\n",
- "    \"data.batch_size\",\n",
- "    cmap=\"viridis\",\n",
- "    log_scale=True,\n",
- ")\n",
- "\n",
- "# Capture the returned theory levels so they can be printed below\n",
- "fig, ax, theory_levels = plot_loss_comparison(\n",
- "    batch_sweep_dir,\n",
- "    template_2d=template_2d,\n",
- "    p1=p1,\n",
- "    p2=p2,\n",
- "    color_mapping=color_map,\n",
- "    colorbar_label=\"Batch Size\",\n",
- "    scalar_map=scalar_map,\n",
- "    num_theory_lines=10,\n",
- "    remove_outliers=True,\n",
- "    log_scale=False,\n",
- ")\n",
- "\n",
- "# Print the extracted values to verify\n",
- "print(\"Batch sizes:\", batch_values)\n",
- "\n",
- "# Print the theory levels (plateau values)\n",
- "print(\"Theory plateau levels:\")\n",
- "for i, level in enumerate(theory_levels):\n",
- "    print(f\"  Plateau {i}: {level:.6e}\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "7acdd373",
- "metadata": {},
- "source": [
- "## MLP Scaling Sweep Heatmap\n",
- "\n",
- "Visualize the relationship between sequence length (k) and hidden dimension (width) for SequentialMLP.\n",
- "\n",
- "**Sweep parameters:**\n",
- "- Model: SequentialMLP\n",
- "- Dimension: 1, p = 10\n",
- "- k values: 2, 3, 4, 5, 6 (5 values)\n",
- "- hidden_dim values: 60, 360, 2160, 12960, 77760 (5 values = 10×6¹ through 10×6⁵)\n",
- "- num_steps varies with k: k=2→50k, k=3→100k, k=4→150k, k=5→200k, k=6→250k\n",
- "- Total: 25 experiments × 3 seeds = 75 runs\n",
- "\n",
- "(The next cell sketches this grid in code.)"
- ]
- },
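- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c0ffee01",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Added sketch (not part of the original sweep code): reconstructs the sweep grid\n",
- "# listed in the markdown above, purely for orientation. Variable names here are\n",
- "# invented and deliberately distinct from k_values used elsewhere in this notebook.\n",
- "sweep_k_values = [2, 3, 4, 5, 6]\n",
- "sweep_hidden_dims = [10 * 6**i for i in range(1, 6)]  # 60, 360, 2160, 12960, 77760\n",
- "sweep_num_steps = {k: 50_000 * (k - 1) for k in sweep_k_values}  # 50k through 250k\n",
- "\n",
- "print(\"hidden_dims:\", sweep_hidden_dims)\n",
- "print(\"num_steps:\", sweep_num_steps)"
- ]
- },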
" \n", - " # Add colorbar\n", - " cbar = plt.colorbar(im, ax=ax)\n", - " cbar_label = 'log₁₀(Train Loss)' if use_log_scale else 'Train Loss'\n", - " cbar.set_label(cbar_label, fontsize=12)\n", - " \n", - " # Add text annotations\n", - "\n", - " # Add text annotations\n", - " for i in range(len(hidden_dims)):\n", - " for j in range(len(k_values)):\n", - " if not np.isnan(grid[i, j]):\n", - " text_val = f\"{grid[i, j]:.2e}\"\n", - " text_color = 'white' if plot_grid[i, j] < plot_grid[~np.isnan(plot_grid)].mean() else 'black'\n", - " ax.text(j, i, text_val, ha='center', va='center', \n", - " color=text_color, fontsize=8)\n", - " \n", - " plt.tight_layout()\n", - " \n", - " if save_path:\n", - " plt.savefig(save_path, dpi=300, bbox_inches='tight')\n", - " print(f\"Figure saved to: {save_path}\")\n", - " \n", - " plt.show()\n", - " \n", - " return grid" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2e5c22df", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b10c9d83", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "group-agf", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.2" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/seq_mlp.ipynb b/notebooks/seq_mlp.ipynb deleted file mode 100644 index 5dd9168..0000000 --- a/notebooks/seq_mlp.ipynb +++ /dev/null @@ -1,2071 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "4265f1a8", - "metadata": {}, - "source": [ - "# MLP Scaling: $H$ vs $k$\n", - "\n", - "Hidden neurons vs sequence length scaling experiments." 
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5a05ce99",
- "metadata": {},
- "source": [
- "## Set up"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5cfe1142",
- "metadata": {},
- "outputs": [],
- "source": [
- "# autoreload\n",
- "%load_ext autoreload\n",
- "%autoreload 2\n",
- "# jupyter black formatter\n",
- "%load_ext jupyter_black\n",
- "\n",
- "import subprocess\n",
- "import os\n",
- "import sys\n",
- "\n",
- "gitroot_path = subprocess.check_output(\n",
- "    [\"git\", \"rev-parse\", \"--show-toplevel\"], universal_newlines=True\n",
- ").strip()\n",
- "\n",
- "os.chdir(gitroot_path)\n",
- "print(\"Working directory: \", os.getcwd())\n",
- "\n",
- "if gitroot_path not in sys.path:\n",
- "    sys.path.insert(0, gitroot_path)\n",
- "print(\"Directory added to path: \", gitroot_path)\n",
- "\n",
- "import yaml\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "from pathlib import Path"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "15a42140",
- "metadata": {},
- "source": [
- "## Specify experiment directory"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "2ebd6750",
- "metadata": {},
- "outputs": [],
- "source": [
- "sweep_dir = \"/data/facosta/sweeps/sweep_mlp_scaling_20251212_161329\"\n",
- "os.path.exists(sweep_dir)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "cb3acce2",
- "metadata": {},
- "source": [
- "### Final Loss Heatmap"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "af291059",
- "metadata": {},
- "outputs": [],
- "source": [
- "def load_sweep_results_grid(sweep_dir: str, k_values: list, hidden_dims: list):\n",
- "    \"\"\"\n",
- "    Load sweep results and organize into a grid for heatmap visualization.\n",
- "\n",
- "    Args:\n",
- "        sweep_dir: Path to the sweep directory\n",
- "        k_values: List of k (sequence length) values\n",
- "        hidden_dims: List of hidden dimension values\n",
- "\n",
- "    Returns:\n",
- "        grid: 2D numpy array with shape (len(hidden_dims), len(k_values))\n",
- "              containing mean final train losses\n",
- "        std_grid: 2D numpy array with standard deviations (if multiple seeds)\n",
- "    \"\"\"\n",
- "    sweep_path = Path(sweep_dir)\n",
- "\n",
- "    # Initialize grids\n",
- "    grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n",
- "    std_grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n",
- "\n",
- "    # Load results for each experiment\n",
- "    for i, h in enumerate(hidden_dims):\n",
- "        for j, k in enumerate(k_values):\n",
- "            exp_name = f\"k{k}_h{h}\"\n",
- "            exp_dir = sweep_path / exp_name\n",
- "\n",
- "            if not exp_dir.exists():\n",
- "                print(f\"Warning: Experiment {exp_name} not found\")\n",
- "                continue\n",
- "\n",
- "            # Load experiment summary\n",
- "            summary_file = exp_dir / \"experiment_summary.yaml\"\n",
- "            if summary_file.exists():\n",
- "                with open(summary_file, \"r\") as f:\n",
- "                    summary = yaml.safe_load(f)\n",
- "\n",
- "                # Get mean train loss\n",
- "                if \"train_loss_stats\" in summary:\n",
- "                    grid[i, j] = summary[\"train_loss_stats\"][\"mean\"]\n",
- "                    std_grid[i, j] = summary[\"train_loss_stats\"][\"std\"]\n",
- "                else:\n",
- "                    print(f\"Warning: No train_loss_stats in {exp_name}\")\n",
- "            else:\n",
- "                print(f\"Warning: No summary file for {exp_name}\")\n",
- "\n",
- "    return grid, std_grid"
- ]
- },
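- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d4e5f6a7",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Added sketch (not from the original notebook): load_sweep_results_grid above\n",
- "# reads experiment_summary.yaml; the field names below are the ones it looks up,\n",
- "# but the numbers are made up for illustration.\n",
- "example_summary = {\n",
- "    \"train_loss_stats\": {\"mean\": 1.2e-4, \"std\": 3.0e-5},\n",
- "}\n",
- "print(yaml.dump(example_summary))"
- ]
- },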
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "54e321b2",
- "metadata": {},
- "outputs": [],
- "source": [
- "k_values = [2, 3, 4, 5, 6, 7, 8]\n",
- "\n",
- "# hidden_dims = [60, 360, 2160, 12960, 77760]\n",
- "hidden_dims = [6, 6**2, 6**3, 6**4, 6**5, 6**6]\n",
- "\n",
- "grid, _ = load_sweep_results_grid(sweep_dir, k_values, hidden_dims)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "2f62021a",
- "metadata": {},
- "outputs": [],
- "source": [
- "plt.figure(figsize=(8, 6))\n",
- "plt.imshow(grid, aspect=\"auto\", norm=None)\n",
- "\n",
- "# Correct labels: rows = hidden_dims (y), columns = k_values (x)\n",
- "plt.xlabel(\"Sequence Length (k)\")\n",
- "plt.ylabel(\"Hidden Dimension\")\n",
- "\n",
- "# Set tick labels to show actual values\n",
- "plt.xticks(range(len(k_values)), k_values)\n",
- "plt.yticks(range(len(hidden_dims)), hidden_dims)\n",
- "\n",
- "plt.gca().invert_yaxis()\n",
- "\n",
- "plt.title(\"Final Train Loss\")\n",
- "\n",
- "plt.colorbar(label=\"Final Train Loss\")\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "fbcf4b02",
- "metadata": {},
- "source": [
- "### Loss Curve Integral Heatmap"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4bb980bb",
- "metadata": {},
- "outputs": [],
- "source": [
- "def load_sweep_results_grid_integral(sweep_dir: str, k_values: list, hidden_dims: list):\n",
- "    \"\"\"\n",
- "    Load sweep results and compute integral of loss curves.\n",
- "\n",
- "    Returns:\n",
- "        grid: 2D array with mean integral of loss curves\n",
- "        std_grid: 2D array with standard deviations across seeds\n",
- "    \"\"\"\n",
- "    sweep_path = Path(sweep_dir)\n",
- "\n",
- "    grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n",
- "    std_grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n",
- "\n",
- "    for i, h in enumerate(hidden_dims):\n",
- "        for j, k in enumerate(k_values):\n",
- "            exp_name = f\"k{k}_h{h}\"\n",
- "            exp_dir = sweep_path / exp_name\n",
- "\n",
- "            if not exp_dir.exists():\n",
- "                continue\n",
- "\n",
- "            # Collect integrals from all seeds\n",
- "            integrals = []\n",
- "            for seed_dir in exp_dir.glob(\"seed_*\"):\n",
- "                loss_file = seed_dir / \"train_loss_history.npy\"\n",
- "                if loss_file.exists():\n",
- "                    loss_history = np.load(loss_file)\n",
- "                    # Compute integral using trapezoidal rule\n",
- "                    integral = np.trapz(loss_history)\n",
- "                    integrals.append(integral)\n",
- "\n",
- "            if integrals:\n",
- "                grid[i, j] = np.mean(integrals)\n",
- "                std_grid[i, j] = np.std(integrals) if len(integrals) > 1 else 0.0\n",
- "\n",
- "    return grid, std_grid"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "8e5017cd",
- "metadata": {},
- "outputs": [],
- "source": [
- "integral_grid, integral_std = load_sweep_results_grid_integral(\n",
- "    sweep_dir, k_values, hidden_dims\n",
- ")\n",
- "\n",
- "from matplotlib.colors import LogNorm, SymLogNorm\n",
- "\n",
- "plt.figure(figsize=(8, 6))\n",
- "plt.imshow(integral_grid, aspect=\"auto\", norm=LogNorm())\n",
- "plt.xlabel(\"Sequence Length (k)\")\n",
- "plt.ylabel(\"Hidden Dimension\")\n",
- "plt.xticks(range(len(k_values)), k_values)\n",
- "plt.yticks(range(len(hidden_dims)), hidden_dims)\n",
- "\n",
- "plt.gca().invert_yaxis()\n",
- "plt.colorbar(label=\"Loss Curve Integral\")\n",
- "plt.title(\"Loss Curve Integral\")\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "6b9956cf",
- "metadata": {},
- "source": [
- "### Steps to Convergence Heatmap"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4bf3dbcc",
- "metadata": {},
- "outputs": [],
- "source": [
- "def load_sweep_results_grid_convergence(\n",
- "    sweep_dir: str, k_values: list, hidden_dims: list, reduction_threshold: float = 0.99\n",
- 
"):\n", - " \"\"\"\n", - " Load sweep results and compute steps to convergence.\n", - "\n", - " Convergence is defined as reaching `reduction_threshold` loss reduction\n", - " (e.g., 0.99 = 99% reduction from initial loss).\n", - "\n", - " If convergence is not reached, the grid point is set to NaN (blacked out).\n", - "\n", - " Args:\n", - " sweep_dir: Path to the sweep directory\n", - " k_values: List of k (sequence length) values\n", - " hidden_dims: List of hidden dimension values\n", - " reduction_threshold: Fraction of loss reduction to consider converged\n", - "\n", - " Returns:\n", - " grid: 2D array with mean steps to convergence (NaN if didn't converge)\n", - " std_grid: 2D array with standard deviations across seeds\n", - " \"\"\"\n", - " sweep_path = Path(sweep_dir)\n", - "\n", - " grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n", - " std_grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n", - "\n", - " for i, h in enumerate(hidden_dims):\n", - " for j, k in enumerate(k_values):\n", - " exp_name = f\"k{k}_h{h}\"\n", - " exp_dir = sweep_path / exp_name\n", - "\n", - " if not exp_dir.exists():\n", - " continue\n", - "\n", - " # Collect convergence steps from all seeds\n", - " convergence_steps = []\n", - " for seed_dir in exp_dir.glob(\"seed_*\"):\n", - " loss_file = seed_dir / \"train_loss_history.npy\"\n", - " if loss_file.exists():\n", - " loss_history = np.load(loss_file)\n", - " initial_loss = loss_history[0]\n", - "\n", - " if initial_loss > 0:\n", - " # Compute reduction at each step\n", - " reductions = 1 - loss_history / initial_loss\n", - "\n", - " # Find first step where reduction >= threshold\n", - " converged_mask = reductions >= reduction_threshold\n", - " if np.any(converged_mask):\n", - " step = np.argmax(converged_mask) # First True\n", - " convergence_steps.append(step)\n", - " # else: Never converged - don't add to list\n", - "\n", - " if convergence_steps:\n", - " grid[i, j] = np.mean(convergence_steps)\n", - " std_grid[i, j] = (\n", - " np.std(convergence_steps) if len(convergence_steps) > 1 else 0.0\n", - " )\n", - " # else: No seeds converged - grid[i,j] remains NaN (blacked out)\n", - "\n", - " return grid, std_grid" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6b083a66", - "metadata": {}, - "outputs": [], - "source": [ - "reduction_threshold = 0.6\n", - "conv_grid, conv_std = load_sweep_results_grid_convergence(\n", - " sweep_dir, k_values, hidden_dims, reduction_threshold=reduction_threshold\n", - ")\n", - "plt.figure(figsize=(10, 6)) # Made slightly wider to accommodate legend\n", - "cmap = plt.cm.viridis_r.copy()\n", - "cmap.set_bad(color=\"black\")\n", - "plt.imshow(conv_grid, aspect=\"equal\", cmap=cmap, norm=LogNorm())\n", - "\n", - "plt.xlabel(\"Sequence Length ($k$)\")\n", - "plt.ylabel(\"Hidden Dimension $H$\")\n", - "plt.xticks(range(len(k_values)), k_values)\n", - "\n", - "# Create y-tick labels with both power notation and actual values\n", - "ytick_labels = [f\"$6^{i+1}$ ({val:,})\" for i, val in enumerate(hidden_dims)]\n", - "plt.yticks(range(len(hidden_dims)), ytick_labels)\n", - "plt.gca().invert_yaxis()\n", - "\n", - "x_step = np.arange(len(k_values)) - 0.5\n", - "y_step = np.minimum(x_step, len(hidden_dims)) # Example: stays within bounds\n", - "\n", - "plt.step(\n", - " x_step,\n", - " y_step,\n", - " where=\"post\",\n", - " color=\"red\",\n", - " linewidth=3,\n", - " linestyle=\"--\",\n", - " label=\"Theory boundary ($H > 6^{k-1}$)\",\n", - ")\n", - "\n", - "# Place legend outside the plot 
area (to the right)\n", - "plt.legend(loc=\"upper center\", bbox_to_anchor=(0.5, -0.12), fontsize=12, frameon=True)\n", - "\n", - "plt.colorbar(label=f\"Steps to {reduction_threshold*100}% Convergence\")\n", - "plt.title(f\"Steps to {reduction_threshold*100}% Convergence (black = did not converge)\")\n", - "plt.tight_layout() # Adjust layout to prevent clipping\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "130132a9", - "metadata": {}, - "outputs": [], - "source": [ - "def load_sweep_results_grid_spikiness(\n", - " sweep_dir: str,\n", - " k_values: list,\n", - " hidden_dims: list,\n", - "):\n", - " \"\"\"\n", - " Compute fraction of training steps where loss increased (instability).\n", - "\n", - " Returns:\n", - " grid: 2D array with mean frac_upward across seeds\n", - " std_grid: 2D array with standard deviations\n", - " \"\"\"\n", - " sweep_path = Path(sweep_dir)\n", - "\n", - " grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n", - " std_grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n", - "\n", - " for i, h in enumerate(hidden_dims):\n", - " for j, k in enumerate(k_values):\n", - " exp_name = f\"k{k}_h{h}\"\n", - " exp_dir = sweep_path / exp_name\n", - "\n", - " if not exp_dir.exists():\n", - " continue\n", - "\n", - " frac_upwards = []\n", - " for seed_dir in exp_dir.glob(\"seed_*\"):\n", - " loss_file = seed_dir / \"train_loss_history.npy\"\n", - " if loss_file.exists():\n", - " loss_history = np.load(loss_file)\n", - " log_loss = np.log10(loss_history + 1e-10)\n", - " log_changes = np.diff(log_loss)\n", - "\n", - " # Fraction of steps where loss went UP\n", - " frac_upward = np.sum(log_changes > 0) / len(log_changes)\n", - " frac_upwards.append(frac_upward)\n", - "\n", - " if frac_upwards:\n", - " grid[i, j] = np.mean(frac_upwards)\n", - " std_grid[i, j] = np.std(frac_upwards) if len(frac_upwards) > 1 else 0.0\n", - "\n", - " return grid, std_grid" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4fba73d0", - "metadata": {}, - "outputs": [], - "source": [ - "# Load both convergence and spikiness data\n", - "reduction_threshold = 0.6 # Adjust as needed\n", - "conv_grid, conv_std = load_sweep_results_grid_convergence(\n", - " sweep_dir, k_values, hidden_dims, reduction_threshold=reduction_threshold\n", - ")\n", - "\n", - "stability_grid, stability_std = load_sweep_results_grid_spikiness(\n", - " sweep_dir, k_values, hidden_dims\n", - ")\n", - "\n", - "# Create a masked version of stability_grid where non-converged runs are NaN\n", - "stability_grid_masked = stability_grid.copy()\n", - "stability_grid_masked[np.isnan(conv_grid)] = np.nan # Mask non-converged runs\n", - "\n", - "# Create plot\n", - "plt.figure(figsize=(10, 6))\n", - "\n", - "# Use a colormap with bad values (NaN) shown as black\n", - "cmap = plt.cm.plasma.copy()\n", - "cmap.set_bad(color=\"black\")\n", - "\n", - "plt.imshow(stability_grid_masked, aspect=\"equal\", cmap=cmap, vmin=0, vmax=0.5)\n", - "\n", - "plt.xlabel(\"Sequence Length ($k$)\")\n", - "plt.ylabel(\"Hidden Dimension $H$\")\n", - "plt.xticks(range(len(k_values)), k_values)\n", - "\n", - "# Create y-tick labels with both power notation and actual values\n", - "ytick_labels = [f\"$6^{i+1}$ ({val:,})\" for i, val in enumerate(hidden_dims)]\n", - "plt.yticks(range(len(hidden_dims)), ytick_labels)\n", - "plt.gca().invert_yaxis()\n", - "\n", - "x_step = np.arange(len(k_values)) - 0.5\n", - "y_step = np.minimum(x_step, len(hidden_dims)) # Example: stays within bounds\n", - 
"\n", - "plt.step(\n", - " x_step,\n", - " y_step,\n", - " where=\"post\",\n", - " color=\"red\",\n", - " linewidth=3,\n", - " linestyle=\"--\",\n", - " label=\"Theory boundary ($H > 6^{k-1}$)\",\n", - ")\n", - "\n", - "plt.legend(loc=\"upper left\", fontsize=10, frameon=True)\n", - "\n", - "plt.colorbar(label=\"Fraction of Upward Steps\")\n", - "plt.title(\n", - " f\"Training Instability\\n(black = did not converge)\",\n", - " fontsize=13,\n", - " fontweight=\"bold\",\n", - ")\n", - "\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# Print summary\n", - "n_converged = np.sum(~np.isnan(conv_grid))\n", - "n_not_converged = np.sum(np.isnan(conv_grid))\n", - "print(f\"\\n{'='*60}\")\n", - "print(f\"Converged runs: {n_converged} ({100*n_converged/conv_grid.size:.1f}%)\")\n", - "print(\n", - " f\"Did not converge (black): {n_not_converged} ({100*n_not_converged/conv_grid.size:.1f}%)\"\n", - ")\n", - "print(f\"{'='*60}\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7a228687", - "metadata": {}, - "outputs": [], - "source": [ - "stability_grid, stability_std = load_sweep_results_grid_spikiness(\n", - " sweep_dir, k_values, hidden_dims\n", - ")\n", - "\n", - "# Create side-by-side subplots\n", - "plt.figure(figsize=(10, 6))\n", - "\n", - "# ========== RIGHT PANEL: SPIKINESS ==========\n", - "# Use a different colormap - plasma, magma, or RdYlGn_r work well\n", - "plt.imshow(stability_grid, aspect=\"equal\", cmap=\"plasma\", vmin=0, vmax=0.5)\n", - "\n", - "\n", - "plt.xlabel(\"Sequence Length ($k$)\")\n", - "plt.ylabel(\"Hidden Dimension $H$\")\n", - "plt.xticks(range(len(k_values)), k_values)\n", - "\n", - "# Create y-tick labels with both power notation and actual values\n", - "ytick_labels = [f\"$6^{i+1}$ ({val:,})\" for i, val in enumerate(hidden_dims)]\n", - "plt.yticks(range(len(hidden_dims)), ytick_labels)\n", - "plt.gca().invert_yaxis()\n", - "\n", - "x_step = np.arange(len(k_values)) - 0.5\n", - "y_step = np.minimum(x_step, len(hidden_dims)) # Example: stays within bounds\n", - "\n", - "plt.step(\n", - " x_step,\n", - " y_step,\n", - " where=\"post\",\n", - " color=\"red\",\n", - " linewidth=3,\n", - " linestyle=\"--\",\n", - " label=\"Theory boundary ($H > 6^{k-1}$)\",\n", - ")\n", - "\n", - "plt.legend(loc=\"upper left\", fontsize=10, frameon=True)\n", - "\n", - "plt.colorbar(label=\"Fraction of Upward Steps\")\n", - "# cbar2.ax.axhline(0.3, color=\"white\", linewidth=2, linestyle=\"--\") # Mark threshold\n", - "plt.title(\n", - " \"Training Instability\\n(higher = more unstable)\", fontsize=13, fontweight=\"bold\"\n", - ")\n", - "\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5eb6bb45", - "metadata": {}, - "outputs": [], - "source": [ - "# Load both metrics\n", - "reduction_threshold = 0.6\n", - "conv_grid, conv_std = load_sweep_results_grid_convergence(\n", - " sweep_dir, k_values, hidden_dims, reduction_threshold=reduction_threshold\n", - ")\n", - "stability_grid, stability_std = load_sweep_results_grid_spikiness(\n", - " sweep_dir, k_values, hidden_dims\n", - ")\n", - "\n", - "# Create binary mask: 1 if converged AND spiky, 0 otherwise\n", - "spikiness_threshold = 0.3 # Adjust this threshold as needed\n", - "\n", - "converged_and_spiky = np.zeros_like(conv_grid)\n", - "for i in range(len(hidden_dims)):\n", - " for j in range(len(k_values)):\n", - " converged = not np.isnan(conv_grid[i, j])\n", - " spiky = stability_grid[i, j] > spikiness_threshold\n", - "\n", - " if 
converged and spiky:\n", - " converged_and_spiky[i, j] = 1.0\n", - " else:\n", - " converged_and_spiky[i, j] = np.nan # Will show as white\n", - "\n", - "# Plot\n", - "plt.figure(figsize=(8, 6.5))\n", - "\n", - "# Custom colormap: white for NaN, red for 1\n", - "cmap = plt.cm.Reds.copy()\n", - "cmap.set_bad(color=\"white\")\n", - "\n", - "im = plt.imshow(converged_and_spiky, aspect=\"equal\", cmap=cmap, vmin=0, vmax=1)\n", - "\n", - "plt.xlabel(\"Sequence Length (k)\", fontsize=12)\n", - "plt.ylabel(\"Hidden Dimension\", fontsize=12)\n", - "plt.xticks(range(len(k_values)), k_values)\n", - "\n", - "ytick_labels = [f\"$6^{i+1}$ ({val:,})\" for i, val in enumerate(hidden_dims)]\n", - "plt.yticks(range(len(hidden_dims)), ytick_labels)\n", - "plt.gca().invert_yaxis()\n", - "\n", - "# Add theory boundary\n", - "x_step = np.arange(len(k_values)) - 0.5\n", - "y_step = np.minimum(x_step, len(hidden_dims))\n", - "plt.step(\n", - " x_step,\n", - " y_step,\n", - " where=\"post\",\n", - " color=\"blue\",\n", - " linewidth=3,\n", - " linestyle=\"--\",\n", - " label=\"Theory boundary\",\n", - ")\n", - "plt.legend(loc=\"upper left\", fontsize=10, frameon=True)\n", - "\n", - "# Add text annotations showing spikiness values in colored cells\n", - "for i in range(len(hidden_dims)):\n", - " for j in range(len(k_values)):\n", - " if converged_and_spiky[i, j] == 1.0:\n", - " spikiness_val = stability_grid[i, j]\n", - " plt.text(\n", - " j,\n", - " i,\n", - " f\"{spikiness_val:.2f}\",\n", - " ha=\"center\",\n", - " va=\"center\",\n", - " fontsize=9,\n", - " color=\"white\",\n", - " fontweight=\"bold\",\n", - " )\n", - "\n", - "plt.colorbar(im, label=\"Converged & Spiky\", ticks=[0, 1])\n", - "plt.title(\n", - " f\"Converged BUT Spiky Runs\\n(threshold: frac_upward > {spikiness_threshold})\",\n", - " fontsize=13,\n", - " fontweight=\"bold\",\n", - ")\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# Print summary\n", - "n_converged_spiky = np.sum(converged_and_spiky == 1.0)\n", - "n_converged_total = np.sum(~np.isnan(conv_grid))\n", - "print(f\"\\n{'='*60}\")\n", - "print(f\"SUMMARY: Converged & Spiky Runs\")\n", - "print(f\"{'='*60}\")\n", - "print(f\"Spikiness threshold: {spikiness_threshold} (frac_upward)\")\n", - "print(f\"Converged & spiky: {n_converged_spiky}\")\n", - "print(f\"Total converged: {n_converged_total}\")\n", - "print(f\"Percentage: {100*n_converged_spiky/n_converged_total:.1f}%\")\n", - "print(f\"{'='*60}\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c7b2bf5c", - "metadata": {}, - "outputs": [], - "source": [ - "# Load both metrics\n", - "reduction_threshold = 0.6\n", - "conv_grid, conv_std = load_sweep_results_grid_convergence(\n", - " sweep_dir, k_values, hidden_dims, reduction_threshold=reduction_threshold\n", - ")\n", - "stability_grid, stability_std = load_sweep_results_grid_spikiness(\n", - " sweep_dir, k_values, hidden_dims\n", - ")\n", - "\n", - "# Parameters\n", - "spikiness_threshold = 0.3\n", - "\n", - "# Create modified grid for plotting\n", - "# Strategy: Use a modified colormap and data array\n", - "plot_grid = conv_grid.copy()\n", - "\n", - "# Create mask for spiky converged runs\n", - "spiky_mask = np.zeros_like(conv_grid, dtype=bool)\n", - "for i in range(len(hidden_dims)):\n", - " for j in range(len(k_values)):\n", - " converged = not np.isnan(conv_grid[i, j])\n", - " spiky = stability_grid[i, j] > spikiness_threshold\n", - " if converged and spiky:\n", - " spiky_mask[i, j] = True\n", - "\n", - "# Plot\n", - "fig, ax = 
plt.subplots(figsize=(10, 6.5))\n", - "\n", - "# First: plot convergence grid with viridis_r (will handle black for NaN)\n", - "cmap_conv = plt.cm.viridis_r.copy()\n", - "cmap_conv.set_bad(color=\"black\")\n", - "im = ax.imshow(plot_grid, aspect=\"equal\", cmap=cmap_conv, norm=LogNorm())\n", - "\n", - "# Second: overlay red patches for spiky converged runs\n", - "for i in range(len(hidden_dims)):\n", - " for j in range(len(k_values)):\n", - " if spiky_mask[i, j]:\n", - " # Draw red square\n", - " rect = plt.Rectangle(\n", - " (j - 0.5, i - 0.5),\n", - " 1,\n", - " 1,\n", - " facecolor=\"red\",\n", - " edgecolor=\"darkred\",\n", - " linewidth=2,\n", - " alpha=0.9,\n", - " )\n", - " ax.add_patch(rect)\n", - "\n", - " # # Add convergence value in white text\n", - " # conv_val = conv_grid[i, j]\n", - " # ax.text(\n", - " # j,\n", - " # i,\n", - " # f\"{int(conv_val)}\",\n", - " # ha=\"center\",\n", - " # va=\"center\",\n", - " # fontsize=8,\n", - " # color=\"white\",\n", - " # fontweight=\"bold\",\n", - " # )\n", - "\n", - "# Formatting\n", - "ax.set_xlabel(\"Sequence Length (k)\", fontsize=12)\n", - "ax.set_ylabel(\"Hidden Dimension\", fontsize=12)\n", - "ax.set_xticks(range(len(k_values)))\n", - "ax.set_xticklabels(k_values)\n", - "\n", - "ytick_labels = [f\"$6^{i+1}$ ({val:,})\" for i, val in enumerate(hidden_dims)]\n", - "ax.set_yticks(range(len(hidden_dims)))\n", - "ax.set_yticklabels(ytick_labels)\n", - "ax.invert_yaxis()\n", - "\n", - "# Add theory boundary\n", - "x_step = np.arange(len(k_values)) - 0.5\n", - "y_step = np.minimum(x_step, len(hidden_dims))\n", - "ax.step(\n", - " x_step,\n", - " y_step,\n", - " where=\"post\",\n", - " color=\"cyan\",\n", - " linewidth=3,\n", - " linestyle=\"--\",\n", - " label=\"Theory boundary\",\n", - ")\n", - "\n", - "# Custom legend\n", - "from matplotlib.patches import Patch\n", - "\n", - "legend_elements = [\n", - " Patch(facecolor=\"black\", label=\"Did not converge\"),\n", - " Patch(\n", - " facecolor=\"red\",\n", - " edgecolor=\"darkred\",\n", - " linewidth=2,\n", - " label=f\"Spiky (frac_up > {spikiness_threshold})\",\n", - " ),\n", - " Patch(facecolor=\"yellow\", label=\"Smooth convergence\"),\n", - " plt.Line2D(\n", - " [0], [0], color=\"cyan\", linewidth=3, linestyle=\"--\", label=\"Theory boundary\"\n", - " ),\n", - "]\n", - "ax.legend(\n", - " handles=legend_elements,\n", - " loc=\"upper center\",\n", - " bbox_to_anchor=(0.5, -0.12),\n", - " fontsize=10,\n", - " frameon=True,\n", - " ncol=4,\n", - ")\n", - "\n", - "plt.colorbar(im, ax=ax, label=f\"Steps to {reduction_threshold*100}% Convergence\")\n", - "ax.set_title(\"Convergence Speed & Spikiness Combined\", fontsize=13, fontweight=\"bold\")\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# Print summary\n", - "n_not_converged = np.sum(np.isnan(conv_grid))\n", - "n_converged_spiky = np.sum(spiky_mask)\n", - "n_converged_smooth = np.sum(~np.isnan(conv_grid)) - n_converged_spiky\n", - "total = conv_grid.size\n", - "\n", - "print(f\"\\n{'='*60}\")\n", - "print(f\"SUMMARY\")\n", - "print(f\"{'='*60}\")\n", - "print(\n", - " f\"Did not converge (black): {n_not_converged:3d} ({100*n_not_converged/total:.1f}%)\"\n", - ")\n", - "print(\n", - " f\"Spiky converged (red): {n_converged_spiky:3d} ({100*n_converged_spiky/total:.1f}%)\"\n", - ")\n", - "print(\n", - " f\"Smooth converged (color): {n_converged_smooth:3d} ({100*n_converged_smooth/total:.1f}%)\"\n", - ")\n", - "print(f\"{'='*60}\\n\")" - ] - }, - { - "cell_type": "markdown", - "id": "2f5cbbf5", - "metadata": {}, - "source": [ - 
"### Curve plot: Convergence steps vs Sequence Length $k$ for different hidden dimensions\n", - "- x-axis: sequence length $k$\n", - "- y-axis: number of steps to convergence\n", - "- different curves for different hidden dimensions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "adddc2c6", - "metadata": {}, - "outputs": [], - "source": [ - "def plot_convergence_vs_k(\n", - " conv_grid,\n", - " k_values,\n", - " hidden_dims,\n", - " save_path=None,\n", - " show=True,\n", - " log_x=True,\n", - " log_y=True,\n", - " reduction_threshold=0.9,\n", - "):\n", - " \"\"\"\n", - " Plot steps to convergence vs sequence length k for different hidden dimensions.\n", - "\n", - " Args:\n", - " conv_grid: 2D array (len(hidden_dims), len(k_values)) with convergence steps\n", - " k_values: List of k (sequence length) values\n", - " hidden_dims: List of hidden dimension values\n", - " save_path: Where to save the plot\n", - " show: Whether to display the plot\n", - " log_x: Whether to use log scale for x-axis\n", - " log_y: Whether to use log scale for y-axis\n", - " reduction_threshold: Threshold used for convergence\n", - " \"\"\"\n", - " fig, ax = plt.subplots(figsize=(10, 6))\n", - "\n", - " # Use a nice sequential colormap for different widths\n", - " colors = plt.cm.plasma(np.linspace(0.15, 0.95, len(hidden_dims)))\n", - "\n", - " for i, (h, color) in enumerate(zip(hidden_dims, colors)):\n", - " # Extract convergence steps for this hidden dim across all k values\n", - " steps_for_h = conv_grid[i, :]\n", - "\n", - " # Only plot converged points\n", - " converged_mask = ~np.isnan(steps_for_h)\n", - " k_converged = np.array(k_values)[converged_mask]\n", - " steps_converged = steps_for_h[converged_mask]\n", - "\n", - " if len(steps_converged) > 0:\n", - " # Plot with line and markers\n", - " ax.plot(\n", - " k_converged,\n", - " steps_converged,\n", - " color=color,\n", - " marker=\"o\",\n", - " markersize=7,\n", - " linewidth=2.5,\n", - " label=f\"h={h:,}\",\n", - " markeredgewidth=0.5,\n", - " markeredgecolor=\"white\",\n", - " )\n", - "\n", - " # Formatting\n", - " ax.set_xlabel(\"Sequence Length ($k$)\", fontsize=14)\n", - " ax.set_ylabel(\"Steps to Convergence\", fontsize=14)\n", - " ax.set_title(\n", - " f\"Steps to {reduction_threshold*100}% Convergence vs Sequence Length $k$\",\n", - " fontsize=16,\n", - " )\n", - " if log_y:\n", - " ax.set_yscale(\"log\")\n", - " if log_x:\n", - " ax.set_xscale(\"log\")\n", - " ax.grid(True, alpha=0.3, which=\"both\", linestyle=\"--\", linewidth=0.5)\n", - " ax.legend(fontsize=11, framealpha=0.9, loc=\"best\")\n", - "\n", - " # Make k values discrete on x-axis\n", - " ax.set_xticks(k_values)\n", - " ax.set_xticklabels(k_values)\n", - "\n", - " plt.tight_layout()\n", - "\n", - " if save_path:\n", - " plt.savefig(save_path, dpi=150, bbox_inches=\"tight\")\n", - " print(f\"Saved to {save_path}\")\n", - "\n", - " if show:\n", - " plt.show()\n", - " else:\n", - " plt.close()\n", - "\n", - " return fig, ax\n", - "\n", - "\n", - "reduction_threshold = 0.9\n", - "conv_grid, conv_std = load_sweep_results_grid_convergence(\n", - " sweep_dir,\n", - " k_values,\n", - " hidden_dims,\n", - " reduction_threshold=reduction_threshold,\n", - ")\n", - "\n", - "\n", - "plot_convergence_vs_k(\n", - " conv_grid=conv_grid,\n", - " k_values=k_values,\n", - " hidden_dims=hidden_dims,\n", - " save_path=None,\n", - " show=True,\n", - " log_x=False,\n", - " log_y=True,\n", - " reduction_threshold=reduction_threshold,\n", - ")" - ] - }, - { - "cell_type": "markdown", 
- "id": "3e814b11", - "metadata": {}, - "source": [ - "### Curve plot : Normalized Convergence Steps vs Sequence Length for different hidden dimensions\n", - "- x-axis: sequence length\n", - "- y-axis: normalized convergence steps ($\\text{steps} / |G|^k$)\n", - "- different curves for different hidden dimensions $H$" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e116a0e1", - "metadata": {}, - "outputs": [], - "source": [ - "def plot_convergence_vs_k_normalized(\n", - " conv_grid,\n", - " k_values,\n", - " hidden_dims,\n", - " p: int = 100, # Vocabulary size\n", - " batch_size: int = 1000, # Batch size used in training\n", - " save_path=None,\n", - " show=True,\n", - " log_x=False,\n", - " log_y=True,\n", - " reduction_threshold=0.9,\n", - "):\n", - " \"\"\"\n", - " Plot fraction of data space seen to convergence vs sequence length k.\n", - "\n", - " Normalizes steps to convergence by the data space size (p^k) to show\n", - " what fraction of the data space needs to be seen for convergence.\n", - "\n", - " Args:\n", - " conv_grid: 2D array (len(hidden_dims), len(k_values)) with convergence steps\n", - " k_values: List of k (sequence length) values\n", - " hidden_dims: List of hidden dimension values\n", - " p: Vocabulary size (data space per token)\n", - " batch_size: Batch size used during training\n", - " save_path: Where to save the plot\n", - " show: Whether to display the plot\n", - " log_x: Whether to use log scale for x-axis\n", - " log_y: Whether to use log scale for y-axis\n", - " reduction_threshold: Threshold used for convergence\n", - " \"\"\"\n", - " fig, ax = plt.subplots(figsize=(10, 6))\n", - "\n", - " # Use a nice sequential colormap for different widths\n", - " colors = plt.cm.plasma(np.linspace(0.15, 0.95, len(hidden_dims)))\n", - "\n", - " for i, (h, color) in enumerate(zip(hidden_dims, colors)):\n", - " # Extract convergence steps for this hidden dim across all k values\n", - " steps_for_h = conv_grid[i, :]\n", - "\n", - " # Only plot converged points\n", - " converged_mask = ~np.isnan(steps_for_h)\n", - " k_converged = np.array(k_values)[converged_mask]\n", - " steps_converged = steps_for_h[converged_mask]\n", - "\n", - " if len(steps_converged) > 0:\n", - " # Normalize by data space size for each k\n", - " # samples_seen = steps * batch_size\n", - " # fraction = samples_seen / p^k\n", - " fractions = []\n", - " for k_val, steps_val in zip(k_converged, steps_converged):\n", - " data_space_size = p**k_val\n", - " samples_seen = steps_val * batch_size\n", - " fraction = samples_seen / data_space_size\n", - " fractions.append(fraction)\n", - "\n", - " # Plot with line and markers\n", - " ax.plot(\n", - " k_converged,\n", - " fractions,\n", - " color=color,\n", - " marker=\"o\",\n", - " markersize=7,\n", - " linewidth=2.5,\n", - " label=f\"h={h:,}\",\n", - " markeredgewidth=0.5,\n", - " markeredgecolor=\"white\",\n", - " )\n", - "\n", - " # Formatting\n", - " ax.set_xlabel(\"Sequence Length (k)\", fontsize=14)\n", - " ax.set_ylabel(\"Data points seen / $p^k$ to convergence\", fontsize=14)\n", - " ax.set_title(\n", - " f\"Data Efficiency to {reduction_threshold*100}% Convergence\",\n", - " fontsize=16,\n", - " )\n", - "\n", - " if log_y:\n", - " ax.set_yscale(\"log\")\n", - " if log_x:\n", - " ax.set_xscale(\"log\")\n", - " else:\n", - " # Make k values discrete on x-axis\n", - " ax.set_xticks(k_values)\n", - " ax.set_xticklabels(k_values)\n", - "\n", - " ax.grid(True, alpha=0.3, which=\"both\", linestyle=\"--\", linewidth=0.5)\n", - " 
ax.legend(fontsize=11, framealpha=0.9, loc=\"best\")\n", - "\n", - " plt.tight_layout()\n", - "\n", - " if save_path:\n", - " plt.savefig(save_path, dpi=150, bbox_inches=\"tight\")\n", - " print(f\"Saved to {save_path}\")\n", - "\n", - " if show:\n", - " plt.show()\n", - " else:\n", - " plt.close()\n", - "\n", - " return fig, ax\n", - "\n", - "\n", - "reduction_threshold = 0.9\n", - "conv_grid, conv_std = load_sweep_results_grid_convergence(\n", - " sweep_dir,\n", - " k_values,\n", - " hidden_dims,\n", - " reduction_threshold=reduction_threshold,\n", - ")\n", - "\n", - "\n", - "plot_convergence_vs_k_normalized(\n", - " conv_grid=conv_grid,\n", - " k_values=k_values,\n", - " hidden_dims=hidden_dims,\n", - " p=10, # Your vocabulary size\n", - " batch_size=1000, # Your batch size\n", - " save_path=None,\n", - " show=True,\n", - " log_x=False,\n", - " log_y=True,\n", - " reduction_threshold=reduction_threshold,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "e0f758af", - "metadata": {}, - "source": [ - "### Curve plot: Loss vs Training Steps for different sequence lengths, fixed hidden dimension\n", - "- x-axis: # training steps\n", - "- y-axis: training loss\n", - "- different curves for different sequence lengths" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7e809783", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from pathlib import Path\n", - "\n", - "\n", - "def plot_loss_curves_fixed_width(\n", - " sweep_dir: str,\n", - " k_values: list,\n", - " hidden_dim: int = 77760,\n", - " seed: int = 0,\n", - " save_path: str = None,\n", - " show: bool = True,\n", - " log_x: bool = True,\n", - " log_y: bool = True,\n", - "):\n", - " \"\"\"\n", - " Plot loss curves for different sequence lengths k with fixed hidden dimension.\n", - "\n", - " Args:\n", - " sweep_dir: Path to sweep directory\n", - " k_values: List of k values to plot (e.g., [2, 3, 4, 5, 6, 7, 8])\n", - " hidden_dim: Fixed hidden dimension (default: 77760)\n", - " seed: Which seed to plot (default: 0)\n", - " save_path: Where to save the plot\n", - " show: Whether to display the plot\n", - " log_x: Whether to use log scale for x-axis\n", - " log_y: Whether to use log scale for y-axis\n", - " \"\"\"\n", - " sweep_path = Path(sweep_dir)\n", - "\n", - " # Create figure\n", - " fig, ax = plt.subplots(figsize=(10, 6))\n", - "\n", - " # Use a nice sequential colormap (plasma, magma, cividis, YlOrRd, etc.)\n", - " colors = plt.cm.plasma(\n", - " np.linspace(0.15, 0.95, len(k_values))\n", - " ) # Avoid too light/dark\n", - "\n", - " for k, color in zip(k_values, colors):\n", - " run_dir = sweep_path / f\"k{k}_h{hidden_dim}\" / f\"seed_{seed}\"\n", - " loss_file = run_dir / \"train_loss_history.npy\"\n", - "\n", - " if not loss_file.exists():\n", - " print(f\"Warning: No data found for k={k}, h={hidden_dim}\")\n", - " continue\n", - "\n", - " # Load loss history\n", - " loss_history = np.load(loss_file)\n", - " steps = np.arange(len(loss_history))\n", - "\n", - " # Plot\n", - " ax.plot(steps, loss_history, color=color, lw=2.5, label=f\"k={k}\")\n", - "\n", - " # Formatting\n", - " ax.set_xlabel(\"Training Steps\", fontsize=14)\n", - " ax.set_ylabel(\"Training Loss\", fontsize=14)\n", - " ax.set_title(f\"Loss vs Training Steps (h={hidden_dim:,})\", fontsize=16)\n", - " if log_x:\n", - " ax.set_xscale(\"log\")\n", - " if log_y:\n", - " ax.set_yscale(\"log\")\n", - " ax.grid(True, alpha=0.3, which=\"both\", linestyle=\"--\", 
linewidth=0.5)\n",
- "    ax.legend(fontsize=11, framealpha=0.9, loc=\"best\")\n",
- "\n",
- "    plt.tight_layout()\n",
- "\n",
- "    if save_path:\n",
- "        plt.savefig(save_path, dpi=150, bbox_inches=\"tight\")\n",
- "        print(f\"Saved to {save_path}\")\n",
- "\n",
- "    if show:\n",
- "        plt.show()\n",
- "    else:\n",
- "        plt.close()\n",
- "\n",
- "    return fig, ax\n",
- "\n",
- "\n",
- "plot_loss_curves_fixed_width(\n",
- "    sweep_dir=sweep_dir,\n",
- "    k_values=[2, 3, 4, 5, 6, 7, 8],\n",
- "    hidden_dim=6**2,\n",
- "    seed=0,\n",
- "    save_path=None,\n",
- "    show=True,\n",
- "    log_x=True,\n",
- "    log_y=True,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e5cd8b97",
- "metadata": {},
- "outputs": [],
- "source": [
- "import pandas as pd  # needed for the metrics table below\n",
- "\n",
- "\n",
- "def compute_spikiness_metrics_upward_only(loss_history):\n",
- "    \"\"\"\n",
- "    Compute spikiness focusing ONLY on upward jumps (loss increases).\n",
- "\n",
- "    This separates:\n",
- "    - Fast learning (large downward jumps) = STABLE\n",
- "    - Instability (upward jumps) = UNSTABLE/SPIKY\n",
- "    \"\"\"\n",
- "    log_loss = np.log10(loss_history + 1e-10)\n",
- "    log_changes = np.diff(log_loss)  # Can be positive or negative\n",
- "\n",
- "    # Separate upward (bad) from downward (good)\n",
- "    upward_spikes = log_changes[log_changes > 0]  # Loss INCREASES\n",
- "    downward_drops = log_changes[log_changes < 0]  # Loss DECREASES\n",
- "\n",
- "    metrics = {}\n",
- "\n",
- "    # Count how many steps are upward vs downward\n",
- "    metrics[\"n_upward\"] = len(upward_spikes)\n",
- "    metrics[\"n_downward\"] = len(downward_drops)\n",
- "    metrics[\"frac_upward\"] = (\n",
- "        len(upward_spikes) / len(log_changes) if len(log_changes) > 0 else 0\n",
- "    )\n",
- "\n",
- "    if len(upward_spikes) > 0:\n",
- "        # Metrics based ONLY on upward spikes (instability)\n",
- "        metrics[\"upward_p95\"] = np.percentile(upward_spikes, 95)\n",
- "        metrics[\"upward_p999\"] = np.percentile(upward_spikes, 99.9)\n",
- "        metrics[\"upward_max\"] = np.max(upward_spikes)\n",
- "        metrics[\"upward_mean\"] = np.mean(upward_spikes)\n",
- "        metrics[\"upward_std\"] = np.std(upward_spikes)\n",
- "    else:\n",
- "        # Perfectly monotonic decrease (never went up!)\n",
- "        metrics[\"upward_p95\"] = 0.0\n",
- "        metrics[\"upward_p999\"] = 0.0\n",
- "        metrics[\"upward_max\"] = 0.0\n",
- "        metrics[\"upward_mean\"] = 0.0\n",
- "        metrics[\"upward_std\"] = 0.0\n",
- "\n",
- "    if len(downward_drops) > 0:\n",
- "        # For reference: how fast is it learning?\n",
- "        metrics[\"downward_p95\"] = np.percentile(\n",
- "            np.abs(downward_drops), 95\n",
- "        )  # Large drops = fast\n",
- "        metrics[\"downward_mean\"] = np.mean(np.abs(downward_drops))\n",
- "    else:\n",
- "        metrics[\"downward_p95\"] = 0.0\n",
- "        metrics[\"downward_mean\"] = 0.0\n",
- "\n",
- "    # Ratio: upward spikes vs downward progress\n",
- "    if metrics[\"downward_mean\"] > 0:\n",
- "        metrics[\"spike_to_progress_ratio\"] = (\n",
- "            metrics[\"upward_mean\"] / metrics[\"downward_mean\"]\n",
- "        )\n",
- "    else:\n",
- "        metrics[\"spike_to_progress_ratio\"] = (\n",
- "            np.inf if metrics[\"upward_mean\"] > 0 else 0.0\n",
- "        )\n",
- "\n",
- "    # Late-stage upward spikes (last 20%)\n",
- "    cutoff = int(0.8 * len(log_changes))\n",
- "    late_changes = log_changes[cutoff:]\n",
- "    late_upward = late_changes[late_changes > 0]\n",
- "\n",
- "    if len(late_upward) > 0:\n",
- "        metrics[\"late_upward_p95\"] = np.percentile(late_upward, 95)\n",
- "        metrics[\"late_upward_max\"] = np.max(late_upward)\n",
- "    else:\n",
- "        metrics[\"late_upward_p95\"] = 0.0\n",
- "        metrics[\"late_upward_max\"] = 
0.0\n", - "\n", - " return metrics\n", - "\n", - "\n", - "def plot_loss_curves_with_upward_metrics(\n", - " sweep_dir: str,\n", - " k_values: list,\n", - " hidden_dim: int = 36,\n", - " seed: int = 0,\n", - " save_path: str = None,\n", - " show: bool = True,\n", - " log_x: bool = True,\n", - " log_y: bool = True,\n", - "):\n", - " \"\"\"Plot loss curves with metrics focused on upward spikes only.\"\"\"\n", - " sweep_path = Path(sweep_dir)\n", - "\n", - " fig, ax = plt.subplots(figsize=(10, 6))\n", - " colors = plt.cm.plasma(np.linspace(0.15, 0.95, len(k_values)))\n", - "\n", - " metrics_data = []\n", - "\n", - " for k, color in zip(k_values, colors):\n", - " run_dir = sweep_path / f\"k{k}_h{hidden_dim}\" / f\"seed_{seed}\"\n", - " loss_file = run_dir / \"train_loss_history.npy\"\n", - "\n", - " if not loss_file.exists():\n", - " print(f\"Warning: No data found for k={k}, h={hidden_dim}\")\n", - " continue\n", - "\n", - " loss_history = np.load(loss_file)\n", - " steps = np.arange(len(loss_history))\n", - "\n", - " ax.plot(steps, loss_history, color=color, lw=2.5, label=f\"k={k}\")\n", - "\n", - " # Compute upward-only metrics\n", - " metrics = compute_spikiness_metrics_upward_only(loss_history)\n", - " metrics[\"k\"] = k\n", - " metrics[\"n_steps\"] = len(loss_history)\n", - " metrics_data.append(metrics)\n", - "\n", - " # Formatting\n", - " ax.set_xlabel(\"Training Steps\", fontsize=14)\n", - " ax.set_ylabel(\"Training Loss\", fontsize=14)\n", - " ax.set_title(f\"Loss vs Training Steps (h={hidden_dim:,})\", fontsize=16)\n", - " if log_x:\n", - " ax.set_xscale(\"log\")\n", - " if log_y:\n", - " ax.set_yscale(\"log\")\n", - " ax.grid(True, alpha=0.3, which=\"both\", linestyle=\"--\", linewidth=0.5)\n", - " ax.legend(fontsize=11, framealpha=0.9, loc=\"best\")\n", - "\n", - " plt.tight_layout()\n", - "\n", - " if save_path:\n", - " plt.savefig(save_path, dpi=150, bbox_inches=\"tight\")\n", - " print(f\"Saved to {save_path}\")\n", - "\n", - " if show:\n", - " plt.show()\n", - " else:\n", - " plt.close()\n", - "\n", - " # Print metrics\n", - " df = pd.DataFrame(metrics_data)\n", - " col_order = [\n", - " \"k\",\n", - " \"n_steps\",\n", - " \"frac_upward\",\n", - " \"upward_p95\",\n", - " \"upward_p999\",\n", - " \"upward_max\",\n", - " \"late_upward_p95\",\n", - " \"spike_to_progress_ratio\",\n", - " \"downward_p95\",\n", - " ]\n", - " df = df[col_order]\n", - "\n", - " print(\"\\n\" + \"=\" * 100)\n", - " print(f\"UPWARD SPIKE METRICS (h={hidden_dim})\")\n", - " print(\"=\" * 100)\n", - " print(\"\\nMetric Definitions:\")\n", - " print(\" frac_upward : Fraction of steps where loss INCREASED\")\n", - " print(\" upward_p95 : 95th percentile of upward jumps (instability)\")\n", - " print(\" upward_p999 : 99.9th percentile of upward jumps\")\n", - " print(\" upward_max : Worst upward spike\")\n", - " print(\" late_upward_p95 : 95th percentile of upward jumps in last 20%\")\n", - " print(\n", - " \" spike_to_progress_ratio : Mean upward / mean downward (higher = more unstable)\"\n", - " )\n", - " print(\n", - " \" downward_p95 : 95th percentile of downward jumps (learning speed)\"\n", - " )\n", - " print(\"\\n\" + \"-\" * 100)\n", - "\n", - " pd.set_option(\"display.max_columns\", None)\n", - " pd.set_option(\"display.width\", None)\n", - " pd.set_option(\"display.float_format\", \"{:.4f}\".format)\n", - " print(df.to_string(index=False))\n", - " print(\"=\" * 100 + \"\\n\")\n", - "\n", - " return fig, ax, df\n", - "\n", - "\n", - "# Call it:\n", - "fig, ax, metrics_df = 
plot_loss_curves_with_upward_metrics(\n", - " sweep_dir=sweep_dir,\n", - " k_values=[2, 3, 4, 5, 6, 7, 8],\n", - " hidden_dim=6**6,\n", - " seed=0,\n", - " save_path=None,\n", - " show=True,\n", - " log_x=True,\n", - " log_y=True,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bd0b5026", - "metadata": {}, - "outputs": [], - "source": [ - "def load_sweep_results_grid_spikiness(\n", - " sweep_dir: str,\n", - " k_values: list,\n", - " hidden_dims: list,\n", - "):\n", - " \"\"\"\n", - " Compute fraction of training steps where loss increased (instability).\n", - "\n", - " Returns:\n", - " grid: 2D array with mean frac_upward across seeds\n", - " std_grid: 2D array with standard deviations\n", - " \"\"\"\n", - " sweep_path = Path(sweep_dir)\n", - "\n", - " grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n", - " std_grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n", - "\n", - " for i, h in enumerate(hidden_dims):\n", - " for j, k in enumerate(k_values):\n", - " exp_name = f\"k{k}_h{h}\"\n", - " exp_dir = sweep_path / exp_name\n", - "\n", - " if not exp_dir.exists():\n", - " continue\n", - "\n", - " frac_upwards = []\n", - " for seed_dir in exp_dir.glob(\"seed_*\"):\n", - " loss_file = seed_dir / \"train_loss_history.npy\"\n", - " if loss_file.exists():\n", - " loss_history = np.load(loss_file)\n", - " log_loss = np.log10(loss_history + 1e-10)\n", - " log_changes = np.diff(log_loss)\n", - "\n", - " # Fraction of steps where loss went UP\n", - " frac_upward = np.sum(log_changes > 0) / len(log_changes)\n", - " frac_upwards.append(frac_upward)\n", - "\n", - " if frac_upwards:\n", - " grid[i, j] = np.mean(frac_upwards)\n", - " std_grid[i, j] = np.std(frac_upwards) if len(frac_upwards) > 1 else 0.0\n", - "\n", - " return grid, std_grid" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6d858f63", - "metadata": {}, - "outputs": [], - "source": [ - "# Compute stability grid\n", - "stability_grid, stability_std = load_sweep_results_grid_spikiness(\n", - " sweep_dir, k_values, hidden_dims\n", - ")\n", - "\n", - "# Plot\n", - "plt.figure(figsize=(8, 6.5))\n", - "plt.imshow(stability_grid, aspect=\"equal\", cmap=\"viridis\") # , norm=LogNorm())\n", - "plt.xlabel(\"Sequence Length (k)\")\n", - "plt.ylabel(\"Hidden Dimension\")\n", - "ytick_labels = [f\"$6^{i+1}$ ({val:,})\" for i, val in enumerate(hidden_dims)]\n", - "plt.yticks(range(len(hidden_dims)), ytick_labels)\n", - "plt.xticks(range(len(k_values)), k_values)\n", - "plt.gca().invert_yaxis()\n", - "plt.colorbar(label=\"Training Spikiness\")\n", - "plt.title(\"Training Spikiness\")\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "298e758e", - "metadata": {}, - "source": [ - "# Varying group size (num frequencies)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "74cd5103", - "metadata": {}, - "outputs": [], - "source": [ - "def load_sweep_results_grid_convergence_3d(\n", - " sweep_dir: str,\n", - " k_values: list,\n", - " hidden_dims: list,\n", - " num_frequencies: int,\n", - " reduction_threshold: float = 0.99,\n", - "):\n", - " \"\"\"\n", - " Load sweep results and compute steps to convergence for 3D sweeps over k, h, and f.\n", - "\n", - " This function is designed for sweeps that include num_frequencies as a parameter,\n", - " using directory naming format: k{k}_h{h}_f{f}\n", - "\n", - " Convergence is defined as reaching `reduction_threshold` loss reduction\n", - " (e.g., 0.99 = 99% reduction from initial 
loss).\n", - "\n", - " If convergence is not reached, the grid point is set to NaN (blacked out).\n", - "\n", - " Args:\n", - " sweep_dir: Path to the sweep directory\n", - " k_values: List of k (sequence length) values\n", - " hidden_dims: List of hidden dimension values\n", - " num_frequencies: Number of frequencies (f parameter)\n", - " reduction_threshold: Fraction of loss reduction to consider converged\n", - "\n", - " Returns:\n", - " grid: 2D array with mean steps to convergence (NaN if didn't converge)\n", - " std_grid: 2D array with standard deviations across seeds\n", - " \"\"\"\n", - " sweep_path = Path(sweep_dir)\n", - "\n", - " grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n", - " std_grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n", - "\n", - " for i, h in enumerate(hidden_dims):\n", - " for j, k in enumerate(k_values):\n", - " exp_name = f\"k{k}_h{h}_f{num_frequencies}\"\n", - " exp_dir = sweep_path / exp_name\n", - "\n", - " if not exp_dir.exists():\n", - " continue\n", - "\n", - " # Collect convergence steps from all seeds\n", - " convergence_steps = []\n", - " for seed_dir in exp_dir.glob(\"seed_*\"):\n", - " loss_file = seed_dir / \"train_loss_history.npy\"\n", - " if loss_file.exists():\n", - " loss_history = np.load(loss_file)\n", - " initial_loss = loss_history[0]\n", - "\n", - " if initial_loss > 0:\n", - " # Compute reduction at each step\n", - " reductions = 1 - loss_history / initial_loss\n", - "\n", - " # Find first step where reduction >= threshold\n", - " converged_mask = reductions >= reduction_threshold\n", - " if np.any(converged_mask):\n", - " step = np.argmax(converged_mask) # First True\n", - " convergence_steps.append(step)\n", - " # else: Never converged - don't add to list\n", - "\n", - " if convergence_steps:\n", - " grid[i, j] = np.mean(convergence_steps)\n", - " std_grid[i, j] = (\n", - " np.std(convergence_steps) if len(convergence_steps) > 1 else 0.0\n", - " )\n", - " # else: No seeds converged - grid[i,j] remains NaN (blacked out)\n", - "\n", - " return grid, std_grid" - ] - }, - { - "cell_type": "markdown", - "id": "418d1ac0", - "metadata": {}, - "source": [ - "## num_freq = 2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "01cf6f19", - "metadata": {}, - "outputs": [], - "source": [ - "# Example using the new 3D sweep (k, h, f)\n", - "new_sweep_dir = \"/home/facosta/group-agf/sweeps/sweep_mlp_scaling_20251212_172318\"\n", - "\n", - "# Define parameter values for the new sweep\n", - "k_values_new = [2, 3, 4, 5, 6, 7, 8]\n", - "hidden_dims_new = [6, 36, 216, 1296, 7776, 46656]\n", - "num_frequencies = 2 # Set this to the frequency value you want to visualize\n", - "\n", - "# Load convergence data for a specific frequency\n", - "reduction_threshold = 0.5\n", - "conv_grid_new, conv_std_new = load_sweep_results_grid_convergence_3d(\n", - " new_sweep_dir,\n", - " k_values_new,\n", - " hidden_dims_new,\n", - " reduction_threshold=reduction_threshold,\n", - " num_frequencies=num_frequencies,\n", - ")\n", - "\n", - "# Plot the heatmap\n", - "plt.figure(figsize=(8, 6))\n", - "cmap = plt.cm.viridis_r.copy()\n", - "cmap.set_bad(color=\"black\")\n", - "plt.imshow(conv_grid_new, aspect=\"equal\", cmap=cmap, norm=LogNorm())\n", - "\n", - "plt.xlabel(\"Sequence Length (k)\")\n", - "plt.ylabel(\"Hidden Dimension\")\n", - "plt.xticks(range(len(k_values_new)), k_values_new)\n", - "plt.yticks(range(len(hidden_dims_new)), hidden_dims_new)\n", - "plt.gca().invert_yaxis()\n", - "\n", - "plt.colorbar(label=f\"Steps to 
{reduction_threshold*100}% Convergence\")\n", - "plt.title(f\"Steps to Convergence (f={num_frequencies}, black = did not converge)\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "ec9b06b2", - "metadata": {}, - "source": [ - "## num_freq = 3" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "32e86647", - "metadata": {}, - "outputs": [], - "source": [ - "# Example using the new 3D sweep (k, h, f)\n", - "new_sweep_dir = \"/home/facosta/group-agf/sweeps/sweep_mlp_scaling_20251212_172318\"\n", - "\n", - "# Define parameter values for the new sweep\n", - "k_values_new = [2, 3, 4, 5, 6, 7, 8]\n", - "hidden_dims_new = [6, 36, 216, 1296, 7776, 46656]\n", - "num_frequencies = 3 # Set this to the frequency value you want to visualize\n", - "\n", - "# Load convergence data for a specific frequency\n", - "reduction_threshold = 0.5\n", - "conv_grid_new, conv_std_new = load_sweep_results_grid_convergence_3d(\n", - " new_sweep_dir,\n", - " k_values_new,\n", - " hidden_dims_new,\n", - " reduction_threshold=reduction_threshold,\n", - " num_frequencies=num_frequencies,\n", - ")\n", - "\n", - "# Plot the heatmap\n", - "plt.figure(figsize=(8, 6))\n", - "cmap = plt.cm.viridis_r.copy()\n", - "cmap.set_bad(color=\"black\")\n", - "plt.imshow(conv_grid_new, aspect=\"equal\", cmap=cmap, norm=LogNorm())\n", - "\n", - "plt.xlabel(\"Sequence Length (k)\")\n", - "plt.ylabel(\"Hidden Dimension\")\n", - "plt.xticks(range(len(k_values_new)), k_values_new)\n", - "plt.yticks(range(len(hidden_dims_new)), hidden_dims_new)\n", - "plt.gca().invert_yaxis()\n", - "\n", - "plt.colorbar(label=f\"Steps to {reduction_threshold*100}% Convergence\")\n", - "plt.title(f\"Steps to Convergence (f={num_frequencies}, black = did not converge)\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "44a4e9c7", - "metadata": {}, - "source": [ - "## num_freq = 4" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e77f1a4e", - "metadata": {}, - "outputs": [], - "source": [ - "# Example using the new 3D sweep (k, h, f)\n", - "new_sweep_dir = \"/home/facosta/group-agf/sweeps/sweep_mlp_scaling_20251212_172318\"\n", - "\n", - "# Define parameter values for the new sweep\n", - "k_values_new = [2, 3, 4, 5, 6, 7, 8]\n", - "hidden_dims_new = [6, 36, 216, 1296, 7776, 46656]\n", - "num_frequencies = 4 # Set this to the frequency value you want to visualize\n", - "\n", - "# Load convergence data for a specific frequency\n", - "reduction_threshold = 0.5\n", - "conv_grid_new, conv_std_new = load_sweep_results_grid_convergence_3d(\n", - " new_sweep_dir,\n", - " k_values_new,\n", - " hidden_dims_new,\n", - " reduction_threshold=reduction_threshold,\n", - " num_frequencies=num_frequencies,\n", - ")\n", - "\n", - "# Plot the heatmap\n", - "plt.figure(figsize=(8, 6))\n", - "cmap = plt.cm.viridis_r.copy()\n", - "cmap.set_bad(color=\"black\")\n", - "plt.imshow(conv_grid_new, aspect=\"equal\", cmap=cmap, norm=LogNorm())\n", - "\n", - "plt.xlabel(\"Sequence Length (k)\")\n", - "plt.ylabel(\"Hidden Dimension\")\n", - "plt.xticks(range(len(k_values_new)), k_values_new)\n", - "plt.yticks(range(len(hidden_dims_new)), hidden_dims_new)\n", - "plt.gca().invert_yaxis()\n", - "\n", - "plt.colorbar(label=f\"Steps to {reduction_threshold*100}% Convergence\")\n", - "plt.title(f\"Steps to Convergence (f={num_frequencies}, black = did not converge)\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "aa9ae86c", - "metadata": {}, - "source": [ - "## num_freq = 5" - ] - }, - { - 
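"cell_type": "code",
- "execution_count": null,
- "id": "agf-freq-loop",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Added sketch (not part of the original sweep code): the per-frequency\n",
- "# cells in this section repeat the same load-and-plot logic. Assuming the\n",
- "# names defined above (new_sweep_dir, k_values_new, hidden_dims_new, and\n",
- "# load_sweep_results_grid_convergence_3d), a single loop over f reproduces\n",
- "# all four heatmaps, including the f=5 one below:\n",
- "for f in [2, 3, 4, 5]:\n",
- "    grid_f, _ = load_sweep_results_grid_convergence_3d(\n",
- "        new_sweep_dir,\n",
- "        k_values_new,\n",
- "        hidden_dims_new,\n",
- "        num_frequencies=f,\n",
- "        reduction_threshold=0.5,\n",
- "    )\n",
- "\n",
- "    plt.figure(figsize=(8, 6))\n",
- "    cmap = plt.cm.viridis_r.copy()\n",
- "    cmap.set_bad(color=\"black\")\n",
- "    plt.imshow(grid_f, aspect=\"equal\", cmap=cmap, norm=LogNorm())\n",
- "    plt.xlabel(\"Sequence Length (k)\")\n",
- "    plt.ylabel(\"Hidden Dimension\")\n",
- "    plt.xticks(range(len(k_values_new)), k_values_new)\n",
- "    plt.yticks(range(len(hidden_dims_new)), hidden_dims_new)\n",
- "    plt.gca().invert_yaxis()\n",
- "    plt.colorbar(label=\"Steps to 50% Convergence\")\n",
- "    plt.title(f\"Steps to Convergence (f={f}, black = did not converge)\")\n",
- "    plt.show()"
- ]
- },
- { - 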
"cell_type": "code", - "execution_count": null, - "id": "0ce18ab9", - "metadata": {}, - "outputs": [], - "source": [ - "# Example using the new 3D sweep (k, h, f)\n", - "new_sweep_dir = \"/home/facosta/group-agf/sweeps/sweep_mlp_scaling_20251212_172318\"\n", - "\n", - "# Define parameter values for the new sweep\n", - "k_values_new = [2, 3, 4, 5, 6, 7, 8]\n", - "hidden_dims_new = [6, 36, 216, 1296, 7776, 46656]\n", - "num_frequencies = 5 # Set this to the frequency value you want to visualize\n", - "\n", - "# Load convergence data for a specific frequency\n", - "reduction_threshold = 0.5\n", - "conv_grid_new, conv_std_new = load_sweep_results_grid_convergence_3d(\n", - " new_sweep_dir,\n", - " k_values_new,\n", - " hidden_dims_new,\n", - " reduction_threshold=reduction_threshold,\n", - " num_frequencies=num_frequencies,\n", - ")\n", - "\n", - "# Plot the heatmap\n", - "plt.figure(figsize=(8, 6))\n", - "cmap = plt.cm.viridis_r.copy()\n", - "cmap.set_bad(color=\"black\")\n", - "plt.imshow(conv_grid_new, aspect=\"equal\", cmap=cmap, norm=LogNorm())\n", - "\n", - "plt.xlabel(\"Sequence Length (k)\")\n", - "plt.ylabel(\"Hidden Dimension\")\n", - "plt.xticks(range(len(k_values_new)), k_values_new)\n", - "plt.yticks(range(len(hidden_dims_new)), hidden_dims_new)\n", - "plt.gca().invert_yaxis()\n", - "\n", - "plt.colorbar(label=f\"Steps to {reduction_threshold*100}% Convergence\")\n", - "plt.title(f\"Steps to Convergence (f={num_frequencies}, black = did not converge)\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "f8585f79", - "metadata": {}, - "source": [ - "### Grid plot: Convergence vs k for different num_frequencies, across different hidden dimensions\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "97c43fc0", - "metadata": {}, - "outputs": [], - "source": [ - "def plot_convergence_vs_k_grid_by_frequency(\n", - " sweep_dir: str,\n", - " k_values: list,\n", - " hidden_dims: list,\n", - " num_frequencies_values: list,\n", - " reduction_threshold: float = 0.99,\n", - " figsize=(18, 12),\n", - " log_x=False,\n", - " log_y=True,\n", - " save_path=None,\n", - " show=True,\n", - "):\n", - " \"\"\"\n", - " Create a grid of plots showing convergence vs k for different frequencies.\n", - "\n", - " Each subplot corresponds to a different hidden dimension.\n", - " Within each subplot, different curves represent different num_frequencies values.\n", - "\n", - " Args:\n", - " sweep_dir: Path to the sweep directory\n", - " k_values: List of k (sequence length) values\n", - " hidden_dims: List of hidden dimension values (one subplot per hidden dim)\n", - " num_frequencies_values: List of num_frequencies values to compare\n", - " reduction_threshold: Threshold for convergence definition\n", - " figsize: Figure size tuple\n", - " log_x: Whether to use log scale for x-axis\n", - " log_y: Whether to use log scale for y-axis\n", - " save_path: Where to save the plot\n", - " show: Whether to display the plot\n", - "\n", - " Returns:\n", - " fig, axes: Matplotlib figure and axes objects\n", - " \"\"\"\n", - " sweep_path = Path(sweep_dir)\n", - "\n", - " # Determine grid layout (2x3 or 3x2 based on number of hidden dims)\n", - " n_plots = len(hidden_dims)\n", - " if n_plots == 6:\n", - " nrows, ncols = 2, 3\n", - " elif n_plots == 4:\n", - " nrows, ncols = 2, 2\n", - " else:\n", - " # General case: aim for squarish layout\n", - " ncols = int(np.ceil(np.sqrt(n_plots)))\n", - " nrows = int(np.ceil(n_plots / ncols))\n", - "\n", - " fig, axes = plt.subplots(nrows, ncols, 
figsize=figsize)\n", - " axes_flat = axes.flatten() if n_plots > 1 else [axes]\n", - "\n", - " # Use a nice colormap for different frequencies\n", - " colors = plt.cm.viridis(np.linspace(0.15, 0.85, len(num_frequencies_values)))\n", - "\n", - " for idx, h in enumerate(hidden_dims):\n", - " ax = axes_flat[idx]\n", - "\n", - " # For each frequency, plot convergence vs k\n", - " for f_idx, num_freq in enumerate(num_frequencies_values):\n", - " convergence_steps_for_k = []\n", - " k_values_converged = []\n", - "\n", - " for k in k_values:\n", - " exp_name = f\"k{k}_h{h}_f{num_freq}\"\n", - " exp_dir = sweep_path / exp_name\n", - "\n", - " if not exp_dir.exists():\n", - " continue\n", - "\n", - " # Collect convergence steps from all seeds\n", - " convergence_steps = []\n", - " for seed_dir in exp_dir.glob(\"seed_*\"):\n", - " loss_file = seed_dir / \"train_loss_history.npy\"\n", - " if loss_file.exists():\n", - " loss_history = np.load(loss_file)\n", - " initial_loss = loss_history[0]\n", - "\n", - " if initial_loss > 0:\n", - " reductions = 1 - loss_history / initial_loss\n", - " converged_mask = reductions >= reduction_threshold\n", - " if np.any(converged_mask):\n", - " step = np.argmax(converged_mask)\n", - " convergence_steps.append(step)\n", - "\n", - " # Take mean across seeds if any converged\n", - " if convergence_steps:\n", - " k_values_converged.append(k)\n", - " convergence_steps_for_k.append(np.mean(convergence_steps))\n", - "\n", - " # Plot this frequency's curve\n", - " if len(k_values_converged) > 0:\n", - " ax.plot(\n", - " k_values_converged,\n", - " convergence_steps_for_k,\n", - " color=colors[f_idx],\n", - " marker=\"o\",\n", - " markersize=6,\n", - " linewidth=2,\n", - " label=f\"f={num_freq}\",\n", - " markeredgewidth=0.5,\n", - " markeredgecolor=\"white\",\n", - " )\n", - "\n", - " # Formatting for this subplot\n", - " ax.set_xlabel(\"Sequence Length (k)\", fontsize=11)\n", - " ax.set_ylabel(\"Steps to Convergence\", fontsize=11)\n", - " ax.set_title(f\"h = {h:,}\", fontsize=13, fontweight=\"bold\")\n", - "\n", - " if log_y:\n", - " ax.set_yscale(\"log\")\n", - " if log_x:\n", - " ax.set_xscale(\"log\")\n", - " else:\n", - " # Make k values discrete on x-axis\n", - " ax.set_xticks(k_values)\n", - " ax.set_xticklabels(k_values)\n", - "\n", - " ax.grid(True, alpha=0.3, which=\"both\", linestyle=\"--\", linewidth=0.5)\n", - " ax.legend(fontsize=9, framealpha=0.9, loc=\"best\")\n", - "\n", - " # Hide any unused subplots\n", - " for idx in range(n_plots, len(axes_flat)):\n", - " axes_flat[idx].axis(\"off\")\n", - "\n", - " # Overall title\n", - " fig.suptitle(\n", - " f\"Convergence vs Sequence Length by Number of Frequencies\\n\"\n", - " f\"({reduction_threshold*100:.0f}% Loss Reduction Threshold)\",\n", - " fontsize=16,\n", - " fontweight=\"bold\",\n", - " y=0.995,\n", - " )\n", - "\n", - " plt.tight_layout()\n", - "\n", - " if save_path:\n", - " plt.savefig(save_path, dpi=150, bbox_inches=\"tight\")\n", - " print(f\"Saved to {save_path}\")\n", - "\n", - " if show:\n", - " plt.show()\n", - " else:\n", - " plt.close()\n", - "\n", - " return fig, axes" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "acd00a0a", - "metadata": {}, - "outputs": [], - "source": [ - "# Example usage: Create grid plot\n", - "new_sweep_dir = \"/home/facosta/group-agf/sweeps/sweep_mlp_scaling_20251212_172318\"\n", - "\n", - "k_values_new = [2, 3, 4, 5, 6, 7, 8]\n", - "hidden_dims_new = [6, 36, 216, 1296, 7776, 46656]\n", - "num_frequencies_values = [2, 3, 4, 5] # All frequency 
values to compare\n",
- "\n",
- "plot_convergence_vs_k_grid_by_frequency(\n",
- "    sweep_dir=new_sweep_dir,\n",
- "    k_values=k_values_new,\n",
- "    hidden_dims=hidden_dims_new,\n",
- "    num_frequencies_values=num_frequencies_values,\n",
- "    reduction_threshold=0.5,\n",
- "    figsize=(18, 12),\n",
- "    log_x=False,\n",
- "    log_y=True,\n",
- "    save_path=None,\n",
- "    show=True,\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "85c094f2",
- "metadata": {},
- "source": [
- "## p=2 experiments"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "3a80c6fc",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Define sweep directory and parameters\n",
- "sweep_dir = \"/home/facosta/group-agf/sweeps/p2_scaling_sweep_20251215_205347\"\n",
- "\n",
- "# Parameters from p2_scaling_sweep.yaml\n",
- "k_values = [2, 3, 4, 5, 6, 7, 8]\n",
- "hidden_dims = [\n",
- "    4,\n",
- "    8,\n",
- "    16,\n",
- "    32,\n",
- "    64,\n",
- "    128,\n",
- "    256,\n",
- "    512,\n",
- "    1024,\n",
- "    2048,\n",
- "    4096,\n",
- "    8192,\n",
- "    16384,\n",
- "    32768,\n",
- "    65536,\n",
- "]\n",
- "\n",
- "# Load convergence data\n",
- "reduction_threshold = 0.6  # 60% loss reduction\n",
- "conv_grid, conv_std = load_sweep_results_grid_convergence(\n",
- "    sweep_dir, k_values, hidden_dims, reduction_threshold=reduction_threshold\n",
- ")\n",
- "\n",
- "\n",
- "from matplotlib.colors import LogNorm\n",
- "\n",
- "# Plot the heatmap\n",
- "plt.figure(figsize=(10, 8))\n",
- "cmap = plt.cm.viridis_r.copy()\n",
- "cmap.set_bad(color=\"black\")\n",
- "plt.imshow(conv_grid, aspect=\"auto\", cmap=cmap, norm=LogNorm())\n",
- "\n",
- "plt.xlabel(\"Sequence Length (k)\", fontsize=12)\n",
- "plt.ylabel(\"Hidden Dimension (h)\", fontsize=12)\n",
- "plt.xticks(range(len(k_values)), k_values)\n",
- "\n",
- "# Create y-tick labels with both power notation and actual values for larger dims\n",
- "ytick_labels = []\n",
- "for h in hidden_dims:\n",
- "    if h >= 1024:\n",
- "        power = int(np.log2(h))\n",
- "        ytick_labels.append(f\"$2^{{{power}}}$ ({h:,})\")\n",
- "    else:\n",
- "        ytick_labels.append(f\"{h}\")\n",
- "\n",
- "plt.yticks(range(len(hidden_dims)), ytick_labels, fontsize=9)\n",
- "plt.gca().invert_yaxis()\n",
- "\n",
- "# Add theoretical boundary line (h > p^(k-1), where p=2)\n",
- "# For p=2 and k=2..8: boundary at h = 2^(k-1), i.e. h=2,4,8,16,32,64,128\n",
- "x_step = np.arange(len(k_values)) - 0.5\n",
- "# Find y index where h = 2^(k-1) for each k\n",
- "y_boundary = []\n",
- "for i, k in enumerate(k_values):\n",
- "    boundary_h = 2 ** (k - 1)\n",
- "    # Find closest hidden_dim index\n",
- "    try:\n",
- "        y_idx = hidden_dims.index(boundary_h)\n",
- "    except ValueError:\n",
- "        # If exact match not found, find closest\n",
- "        y_idx = np.argmin(np.abs(np.array(hidden_dims) - boundary_h))\n",
- "    y_boundary.append(y_idx)\n",
- "\n",
- "plt.step(\n",
- "    x_step,\n",
- "    y_boundary,\n",
- "    where=\"post\",\n",
- "    color=\"red\",\n",
- "    linewidth=3,\n",
- "    linestyle=\"--\",\n",
- "    label=r\"Theory boundary ($h > 2^{k-1}$)\",\n",
- ")\n",
- "\n",
- "plt.legend(loc=\"upper left\", fontsize=11, frameon=True)\n",
- "plt.colorbar(label=f\"Steps to {reduction_threshold*100:.0f}% Convergence\")\n",
- "plt.title(\n",
- "    f\"p=2 Scaling: Steps to {reduction_threshold*100:.0f}% Convergence\\n(black = did not converge)\",\n",
- "    fontsize=13,\n",
- "    fontweight=\"bold\",\n",
- ")\n",
- "plt.tight_layout()\n",
- "plt.show()\n",
- "\n",
- "# Print summary statistics\n",
- "n_not_converged = np.sum(np.isnan(conv_grid))\n",
- "n_converged = 
np.sum(~np.isnan(conv_grid))\n", - "total = conv_grid.size\n", - "\n", - "print(f\"\\n{'='*60}\")\n", - "print(f\"CONVERGENCE SUMMARY (p=2 scaling)\")\n", - "print(f\"{'='*60}\")\n", - "print(f\"Converged: {n_converged:3d} ({100*n_converged/total:.1f}%)\")\n", - "print(f\"Did not converge: {n_not_converged:3d} ({100*n_not_converged/total:.1f}%)\")\n", - "print(f\"Total experiments: {total:3d}\")\n", - "print(f\"{'='*60}\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b1e10a36", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "group-agf", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.2" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/seq_mlp_group_size.ipynb b/notebooks/seq_mlp_group_size.ipynb deleted file mode 100644 index a30e7ec..0000000 --- a/notebooks/seq_mlp_group_size.ipynb +++ /dev/null @@ -1,1393 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "c8c5c4b6", - "metadata": {}, - "source": [ - "# MLP Scaling: $H$ vs $|G|$ \n", - "\n", - "Hidden neurons vs group size scaling experiments." - ] - }, - { - "cell_type": "markdown", - "id": "155908c2", - "metadata": {}, - "source": [ - "## Set up" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7fc4c5b6", - "metadata": {}, - "outputs": [], - "source": [ - "# autoreload\n", - "%load_ext autoreload\n", - "%autoreload 2\n", - "# jupyter black formatter\n", - "%load_ext jupyter_black\n", - "\n", - "import subprocess\n", - "import os\n", - "import sys\n", - "\n", - "gitroot_path = subprocess.check_output(\n", - " [\"git\", \"rev-parse\", \"--show-toplevel\"], universal_newlines=True\n", - ").strip()\n", - "\n", - "os.chdir(gitroot_path)\n", - "print(\"Working directory: \", os.getcwd())\n", - "\n", - "if gitroot_path not in sys.path:\n", - " sys.path.insert(0, gitroot_path)\n", - "print(\"Directory added to path: \", gitroot_path)\n", - "\n", - "import yaml\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from pathlib import Path" - ] - }, - { - "cell_type": "markdown", - "id": "9831010d", - "metadata": {}, - "source": [ - "## Specify experiment directory" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b9f8fc25", - "metadata": {}, - "outputs": [], - "source": [ - "# sweep_dir = \"/home/facosta/group-agf/sweeps/onehot_scaling_sweep_20251215_175955\"\n", - "sweep_dir = \"/home/facosta/group-agf/sweep_results/onehot_scaling_sweep_20260112_022012\"\n", - "print(os.path.exists(sweep_dir))" - ] - }, - { - "cell_type": "markdown", - "id": "d8342c22", - "metadata": {}, - "source": [ - "### Steps to Convergence" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bc6dd932", - "metadata": {}, - "outputs": [], - "source": [ - "def load_sweep_results_grid_convergence_p_h(\n", - " sweep_dir: str,\n", - " k: int,\n", - " p_values: list,\n", - " hidden_dims: list,\n", - " reduction_threshold: float = 0.99,\n", - " max_p: int = None,\n", - "):\n", - " \"\"\"\n", - " Load sweep results and compute steps to convergence for p vs hidden_dim sweep.\n", - "\n", - " Updated for experiment naming: k{k}_p{p}_h{h}\n", - " Only loads completed experiments (checks for run_summary.yaml).\n", - "\n", - " 
Convergence is defined as reaching `reduction_threshold` loss reduction\n", - " (e.g., 0.99 = 99% reduction from initial loss).\n", - "\n", - " If convergence is not reached, the grid point is set to NaN (blacked out).\n", - "\n", - " Args:\n", - " sweep_dir: Path to the sweep directory\n", - " k: Sequence length parameter (2, 3, or 4)\n", - " p_values: List of p (group size) values\n", - " hidden_dims: List of hidden dimension values\n", - " reduction_threshold: Fraction of loss reduction to consider converged\n", - " max_p: Maximum p value to include (filters incomplete experiments)\n", - "\n", - " Returns:\n", - " grid: 2D array with mean steps to convergence (NaN if didn't converge)\n", - " Shape: (len(hidden_dims), len(p_values))\n", - " std_grid: 2D array with standard deviations across seeds\n", - " \"\"\"\n", - " sweep_path = Path(sweep_dir)\n", - "\n", - " grid = np.full((len(hidden_dims), len(p_values)), np.nan)\n", - " std_grid = np.full((len(hidden_dims), len(p_values)), np.nan)\n", - "\n", - " for i, h in enumerate(hidden_dims):\n", - " for j, p in enumerate(p_values):\n", - " # Filter by max_p if specified\n", - " if max_p is not None and p > max_p:\n", - " continue\n", - "\n", - " exp_name = f\"k{k}_p{p}_h{h}\"\n", - " exp_dir = sweep_path / exp_name\n", - "\n", - " if not exp_dir.exists():\n", - " continue\n", - "\n", - " # Check if experiment is completed (has run_summary.yaml)\n", - " seed_dir = exp_dir / \"seed_0\"\n", - " if not seed_dir.exists() or not (seed_dir / \"run_summary.yaml\").exists():\n", - " continue # Skip incomplete experiments\n", - "\n", - " # Collect convergence steps from all seeds\n", - " convergence_steps = []\n", - " for seed_dir in exp_dir.glob(\"seed_*\"):\n", - " loss_file = seed_dir / \"train_loss_history.npy\"\n", - " if loss_file.exists():\n", - " loss_history = np.load(loss_file)\n", - " initial_loss = loss_history[0]\n", - "\n", - " if initial_loss > 0:\n", - " # Compute reduction at each step\n", - " reductions = 1 - loss_history / initial_loss\n", - "\n", - " # Find first step where reduction >= threshold\n", - " converged_mask = reductions >= reduction_threshold\n", - " if np.any(converged_mask):\n", - " step = np.argmax(converged_mask) # First True\n", - " convergence_steps.append(step)\n", - " # else: Never converged - don't add to list\n", - "\n", - " if convergence_steps:\n", - " grid[i, j] = np.mean(convergence_steps)\n", - " std_grid[i, j] = (\n", - " np.std(convergence_steps) if len(convergence_steps) > 1 else 0.0\n", - " )\n", - " # else: No seeds converged - grid[i,j] remains NaN (blacked out)\n", - "\n", - " return grid, std_grid" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9a87f24d", - "metadata": {}, - "outputs": [], - "source": [ - "def load_sweep_results_grid_final_loss_p_h(\n", - " sweep_dir: str,\n", - " k: int,\n", - " p_values: list,\n", - " hidden_dims: list,\n", - " max_p: int = None,\n", - "):\n", - " \"\"\"\n", - " Load sweep results and compute final training loss for p vs hidden_dim sweep.\n", - "\n", - " Updated for experiment naming: k{k}_p{p}_h{h}\n", - " Only loads completed experiments (checks for run_summary.yaml).\n", - "\n", - " Args:\n", - " sweep_dir: Path to the sweep directory\n", - " k: Sequence length parameter (2, 3, or 4)\n", - " p_values: List of p (group size) values\n", - " hidden_dims: List of hidden dimension values\n", - " max_p: Maximum p value to include (filters incomplete experiments)\n", - "\n", - " Returns:\n", - " grid: 2D array with mean final training loss 
(NaN if experiment incomplete)\n", - " Shape: (len(hidden_dims), len(p_values))\n", - " std_grid: 2D array with standard deviations across seeds\n", - " \"\"\"\n", - " sweep_path = Path(sweep_dir)\n", - "\n", - " grid = np.full((len(hidden_dims), len(p_values)), np.nan)\n", - " std_grid = np.full((len(hidden_dims), len(p_values)), np.nan)\n", - "\n", - " for i, h in enumerate(hidden_dims):\n", - " for j, p in enumerate(p_values):\n", - " # Filter by max_p if specified\n", - " if max_p is not None and p > max_p:\n", - " continue\n", - "\n", - " exp_name = f\"k{k}_p{p}_h{h}\"\n", - " exp_dir = sweep_path / exp_name\n", - "\n", - " if not exp_dir.exists():\n", - " continue\n", - "\n", - " # Check if experiment is completed (has run_summary.yaml)\n", - " seed_dir = exp_dir / \"seed_0\"\n", - " if not seed_dir.exists() or not (seed_dir / \"run_summary.yaml\").exists():\n", - " continue # Skip incomplete experiments\n", - "\n", - " # Collect final losses from all seeds\n", - " final_losses = []\n", - " for seed_dir in exp_dir.glob(\"seed_*\"):\n", - " loss_file = seed_dir / \"train_loss_history.npy\"\n", - " if loss_file.exists():\n", - " loss_history = np.load(loss_file)\n", - " if len(loss_history) > 0:\n", - " final_loss = loss_history[-1] # Last value\n", - " final_losses.append(final_loss)\n", - "\n", - " if final_losses:\n", - " grid[i, j] = np.mean(final_losses)\n", - " std_grid[i, j] = np.std(final_losses) if len(final_losses) > 1 else 0.0\n", - " # else: No seeds found - grid[i,j] remains NaN (blacked out)\n", - "\n", - " return grid, std_grid" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3bb53f80", - "metadata": {}, - "outputs": [], - "source": [ - "def load_training_loss_curves_p(\n", - " sweep_dir: str,\n", - " k: int,\n", - " hidden_dim: int,\n", - " p_values: list,\n", - "):\n", - " \"\"\"\n", - " Load training loss histories for different group sizes (p) with fixed k and hidden_dim.\n", - "\n", - " Args:\n", - " sweep_dir: Path to the sweep directory\n", - " k: Sequence length parameter (fixed)\n", - " hidden_dim: Hidden dimension (fixed)\n", - " p_values: List of p (group size) values to plot\n", - "\n", - " Returns:\n", - " curves: Dictionary mapping p -> list of loss histories (one per seed)\n", - " Each loss history is a numpy array\n", - " \"\"\"\n", - " sweep_path = Path(sweep_dir)\n", - "\n", - " curves = {}\n", - "\n", - " for p in p_values:\n", - " exp_name = f\"k{k}_p{p}_h{hidden_dim}\"\n", - " exp_dir = sweep_path / exp_name\n", - "\n", - " if not exp_dir.exists():\n", - " continue\n", - "\n", - " # Check if experiment is completed\n", - " seed_dir = exp_dir / \"seed_0\"\n", - " if not seed_dir.exists() or not (seed_dir / \"run_summary.yaml\").exists():\n", - " continue # Skip incomplete experiments\n", - "\n", - " # Collect loss histories from all seeds\n", - " loss_histories = []\n", - " for seed_dir in exp_dir.glob(\"seed_*\"):\n", - " loss_file = seed_dir / \"train_loss_history.npy\"\n", - " if loss_file.exists():\n", - " loss_history = np.load(loss_file)\n", - " if len(loss_history) > 0:\n", - " loss_histories.append(loss_history)\n", - "\n", - " if loss_histories:\n", - " curves[p] = loss_histories\n", - "\n", - " return curves" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bf14dee1", - "metadata": {}, - "outputs": [], - "source": [ - "def load_sweep_results_grid_final_val_loss_p_h(\n", - " sweep_dir: str,\n", - " k: int,\n", - " p_values: list,\n", - " hidden_dims: list,\n", - " max_p: int = None,\n", - "):\n", 
- " \"\"\"\n", - " Load sweep results and compute final validation loss for p vs hidden_dim sweep.\n", - "\n", - " Updated for experiment naming: k{k}_p{p}_h{h}\n", - " Only loads completed experiments (checks for run_summary.yaml).\n", - "\n", - " Args:\n", - " sweep_dir: Path to the sweep directory\n", - " k: Sequence length parameter (2, 3, or 4)\n", - " p_values: List of p (group size) values\n", - " hidden_dims: List of hidden dimension values\n", - " max_p: Maximum p value to include (filters incomplete experiments)\n", - "\n", - " Returns:\n", - " grid: 2D array with mean final validation loss (NaN if experiment incomplete)\n", - " Shape: (len(hidden_dims), len(p_values))\n", - " std_grid: 2D array with standard deviations across seeds\n", - " \"\"\"\n", - " sweep_path = Path(sweep_dir)\n", - "\n", - " grid = np.full((len(hidden_dims), len(p_values)), np.nan)\n", - " std_grid = np.full((len(hidden_dims), len(p_values)), np.nan)\n", - "\n", - " for i, h in enumerate(hidden_dims):\n", - " for j, p in enumerate(p_values):\n", - " # Filter by max_p if specified\n", - " if max_p is not None and p > max_p:\n", - " continue\n", - "\n", - " exp_name = f\"k{k}_p{p}_h{h}\"\n", - " exp_dir = sweep_path / exp_name\n", - "\n", - " if not exp_dir.exists():\n", - " continue\n", - "\n", - " # Check if experiment is completed (has run_summary.yaml)\n", - " seed_dir = exp_dir / \"seed_0\"\n", - " if not seed_dir.exists() or not (seed_dir / \"run_summary.yaml\").exists():\n", - " continue # Skip incomplete experiments\n", - "\n", - " # Collect final validation losses from all seeds\n", - " final_losses = []\n", - " for seed_dir in exp_dir.glob(\"seed_*\"):\n", - " loss_file = seed_dir / \"val_loss_history.npy\"\n", - " if loss_file.exists():\n", - " loss_history = np.load(loss_file)\n", - " if len(loss_history) > 0:\n", - " final_loss = loss_history[-1] # Last value\n", - " final_losses.append(final_loss)\n", - "\n", - " if final_losses:\n", - " grid[i, j] = np.mean(final_losses)\n", - " std_grid[i, j] = np.std(final_losses) if len(final_losses) > 1 else 0.0\n", - " # else: No seeds found - grid[i,j] remains NaN (blacked out)\n", - "\n", - " return grid, std_grid" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "42ce6ffd", - "metadata": {}, - "outputs": [], - "source": [ - "# Define parameter values from the sweep config\n", - "# Filter to p <= 55 for completed experiments\n", - "p_values = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70]\n", - "hidden_dims = [80, 160, 240, 320, 400, 480, 560, 640, 720, 800, 880, 960, 1040, 1120]\n", - "k_values = [2, 3] # , 4] # Different k values to plot separately" - ] - }, - { - "cell_type": "markdown", - "id": "7bf99dee", - "metadata": {}, - "source": [ - "### Plot steps to convergence grid" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "522570f5", - "metadata": {}, - "outputs": [], - "source": [ - "# Load convergence data for each k value separately\n", - "reduction_threshold = 0.90\n", - "max_p = 70 # Only visualize completed experiments (p <= 55)\n", - "\n", - "from matplotlib.colors import LogNorm\n", - "\n", - "# Create separate plots for each k value\n", - "for k in k_values:\n", - " conv_grid, conv_std = load_sweep_results_grid_convergence_p_h(\n", - " sweep_dir,\n", - " k,\n", - " p_values,\n", - " hidden_dims,\n", - " reduction_threshold=reduction_threshold,\n", - " max_p=max_p,\n", - " )\n", - "\n", - " # Filter p values - only show p <= max_p\n", - " p_values_filtered = [p for p in p_values if p <= 
max_p]\n",
- "\n",
- "    # Plot convergence heatmap: p (group size) vs hidden_dim\n",
- "    plt.figure(figsize=(12, 8))\n",
- "    cmap = plt.cm.viridis_r.copy()\n",
- "    cmap.set_bad(color=\"black\")\n",
- "    # Set extent to align cells with tick positions\n",
- "    # extent: [left, right, bottom, top] in data coordinates\n",
- "    plt.imshow(\n",
- "        conv_grid[:, : len(p_values_filtered)],\n",
- "        aspect=\"equal\",\n",
- "        cmap=cmap,\n",
- "        norm=LogNorm(),\n",
- "    )\n",
- "\n",
- "    plt.xlabel(\"Group Size $|G|$\", fontsize=14)\n",
- "    plt.ylabel(\"Hidden Dimension $H$\", fontsize=14)\n",
- "    plt.xticks(\n",
- "        range(len(p_values_filtered)), p_values_filtered, rotation=45, ha=\"center\"\n",
- "    )\n",
- "\n",
- "    # Set y-axis ticks (hidden dimensions)\n",
- "    plt.yticks(range(len(hidden_dims)), hidden_dims)\n",
- "    plt.gca().invert_yaxis()\n",
- "\n",
- "    # Theory boundaries\n",
- "    x_step = np.arange(len(p_values_filtered) + 1) - 0.5\n",
- "\n",
- "    # Upper boundary: H = (k+1)*2^{k-1} * |G|, scaled here by reduction_threshold\n",
- "    upper_boundary_coeff = (k + 1) * (2 ** (k - 1)) * reduction_threshold\n",
- "    y_step_upper = [\n",
- "        min(\n",
- "            len(hidden_dims) - 1,\n",
- "            (\n",
- "                # Find the first H that satisfies H >= upper_boundary_coeff * p\n",
- "                np.argmax(np.array(hidden_dims) >= upper_boundary_coeff * p)\n",
- "                if upper_boundary_coeff * p <= max(hidden_dims)\n",
- "                else len(hidden_dims) - 1\n",
- "            ),\n",
- "        )\n",
- "        for p in p_values_filtered\n",
- "    ]\n",
- "    y_step_upper.append(y_step_upper[-1])  # Extend for step plot\n",
- "    # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n",
- "    y_step_upper = [y - 0.5 for y in y_step_upper]\n",
- "\n",
- "    # Lower boundary: H = 2^{k-1} * |G|, scaled here by reduction_threshold\n",
- "    lower_boundary_coeff = 2 ** (k - 1) * reduction_threshold\n",
- "    y_step_lower = [\n",
- "        min(\n",
- "            len(hidden_dims) - 1,\n",
- "            (\n",
- "                # Find the first H that satisfies H >= lower_boundary_coeff * p\n",
- "                np.argmax(np.array(hidden_dims) >= lower_boundary_coeff * p)\n",
- "                if lower_boundary_coeff * p <= max(hidden_dims)\n",
- "                else len(hidden_dims) - 1\n",
- "            ),\n",
- "        )\n",
- "        for p in p_values_filtered\n",
- "    ]\n",
- "    y_step_lower.append(y_step_lower[-1])  # Extend for step plot\n",
- "    # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n",
- "    y_step_lower = [y - 0.5 for y in y_step_lower]\n",
- "\n",
- "    plt.step(\n",
- "        x_step,\n",
- "        y_step_upper,\n",
- "        where=\"post\",\n",
- "        color=\"red\",\n",
- "        linewidth=4,\n",
- "        linestyle=\"-\",\n",
- "        label=f\"Upper boundary ($H = (k+1) \\\\cdot 2^{{k-1}} |G| \\\\times {reduction_threshold} = {upper_boundary_coeff:g} |G|$)\",\n",
- "    )\n",
- "\n",
- "    plt.step(\n",
- "        x_step,\n",
- "        y_step_lower,\n",
- "        where=\"post\",\n",
- "        color=\"white\",\n",
- "        linewidth=4,\n",
- "        linestyle=\"-\",\n",
- "        label=f\"Lower boundary ($H = 2^{{k-1}} |G| \\\\times {reduction_threshold} = {lower_boundary_coeff:g} |G|$)\",\n",
- "    )\n",
- "\n",
- "    # Place legend outside the plot area\n",
- "    plt.legend(\n",
- "        loc=\"upper center\", bbox_to_anchor=(0.5, -0.12), fontsize=12, frameon=True\n",
- "    )\n",
- "\n",
- "    plt.colorbar(label=f\"Steps to {reduction_threshold*100:.0f}% Convergence\")\n",
- "    plt.title(\n",
- "        f\"Steps to {reduction_threshold*100:.0f}% Convergence: Group Size $|G|$ vs Hidden Dimension $H$\\n($k={k}$, black = did not converge, p ≤ {max_p})\",\n",
- "        fontsize=14,\n",
- "        fontweight=\"bold\",\n",
- "    )\n",
- "    plt.tight_layout()\n",
- "    plt.show()"
- ]
- },
- {
- "cell_type": 
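"markdown",
- "id": "agf-boundary-note",
- "metadata": {},
- "source": [
- "*Added sanity check on the plotted boundaries:* for $k=3$ they are $H = 2^{k-1}|G| = 4|G|$ (lower) and $H = (k+1)\\,2^{k-1}|G| = 16|G|$ (upper), so at $|G| = 40$ the step lines should sit near $H = 160$ and $H = 640$; note the cell above additionally scales both coefficients by `reduction_threshold`."
- ]
- },
- {
- "cell_type": 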
"markdown", - "id": "28e479df", - "metadata": {}, - "source": [ - "### Final Training Loss\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "06fb5f82", - "metadata": {}, - "outputs": [], - "source": [ - "# Load final training loss data for each k value separately\n", - "max_p = 70 # Only visualize completed experiments (p <= 55)\n", - "\n", - "from matplotlib.colors import LogNorm\n", - "\n", - "# Create separate plots for each k value\n", - "for k in k_values:\n", - " loss_grid, loss_std = load_sweep_results_grid_final_loss_p_h(\n", - " sweep_dir,\n", - " k,\n", - " p_values,\n", - " hidden_dims,\n", - " max_p=max_p,\n", - " )\n", - "\n", - " # Filter p values - only show p <= max_p\n", - " p_values_filtered = [p for p in p_values if p <= max_p]\n", - "\n", - " # Plot final loss heatmap: p (group size) vs hidden_dim\n", - " plt.figure(figsize=(12, 8))\n", - " cmap = plt.cm.viridis_r.copy()\n", - " cmap.set_bad(color=\"black\")\n", - " # Set extent to align cells with tick positions\n", - " # extent: [left, right, bottom, top] in data coordinates\n", - " plt.imshow(\n", - " loss_grid[:, : len(p_values_filtered)],\n", - " aspect=\"equal\",\n", - " cmap=cmap,\n", - " norm=LogNorm(),\n", - " )\n", - "\n", - " plt.xlabel(\"Group Size $|G|$\", fontsize=14)\n", - " plt.ylabel(\"Hidden Dimension $H$\", fontsize=14)\n", - " plt.xticks(\n", - " range(len(p_values_filtered)), p_values_filtered, rotation=45, ha=\"center\"\n", - " )\n", - "\n", - " # Set y-axis ticks (hidden dimensions)\n", - " plt.yticks(range(len(hidden_dims)), hidden_dims)\n", - " plt.gca().invert_yaxis()\n", - "\n", - " # Theory boundaries\n", - " x_step = np.arange(len(p_values_filtered) + 1) - 0.5\n", - "\n", - " # Upper boundary: H = (k+1)*2^{k-1} * |G|\n", - " upper_boundary_coeff = (k + 1) * (2 ** (k - 1))\n", - " y_step_upper = [\n", - " min(\n", - " len(hidden_dims) - 1,\n", - " (\n", - " # Find the first H that satisfies H >= upper_boundary_coeff * p\n", - " np.argmax(np.array(hidden_dims) >= upper_boundary_coeff * p)\n", - " if upper_boundary_coeff * p <= max(hidden_dims)\n", - " else len(hidden_dims) - 1\n", - " ),\n", - " )\n", - " for p in p_values_filtered\n", - " ]\n", - " y_step_upper.append(y_step_upper[-1]) # Extend for step plot\n", - " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n", - " y_step_upper = [y - 0.5 for y in y_step_upper]\n", - "\n", - " # Lower boundary: H = 2^{k-1} * |G|\n", - " lower_boundary_coeff = 2 ** (k - 1)\n", - " y_step_lower = [\n", - " min(\n", - " len(hidden_dims) - 1,\n", - " (\n", - " # Find the first H that satisfies H >= lower_boundary_coeff * p\n", - " np.argmax(np.array(hidden_dims) >= lower_boundary_coeff * p)\n", - " if lower_boundary_coeff * p <= max(hidden_dims)\n", - " else len(hidden_dims) - 1\n", - " ),\n", - " )\n", - " for p in p_values_filtered\n", - " ]\n", - " y_step_lower.append(y_step_lower[-1]) # Extend for step plot\n", - " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n", - " y_step_lower = [y - 0.5 for y in y_step_lower]\n", - "\n", - " plt.step(\n", - " x_step,\n", - " y_step_upper,\n", - " where=\"post\",\n", - " color=\"magenta\",\n", - " linewidth=3,\n", - " linestyle=\"--\",\n", - " label=f\"Upper boundary ($H$ ≥ $(k+1) \\\\cdot 2^{{k-1}} |G|$)\",\n", - " )\n", - "\n", - " plt.step(\n", - " x_step,\n", - " y_step_lower,\n", - " where=\"post\",\n", - " color=\"red\",\n", - " linewidth=3,\n", - " linestyle=\"--\",\n", - " label=f\"Lower boundary ($H$ ≥ $2^{{k-1}} 
|G|$)\",\n", - " )\n", - "\n", - " # Place legend outside the plot area\n", - " plt.legend(\n", - " loc=\"upper center\", bbox_to_anchor=(0.5, -0.12), fontsize=12, frameon=True\n", - " )\n", - "\n", - " plt.colorbar(label=\"Final Training Loss\")\n", - " plt.title(\n", - " f\"Final Training Loss: Group Size $|G|$ vs Hidden Dimension $H$\\n($k={k}$)\",\n", - " fontsize=14,\n", - " fontweight=\"bold\",\n", - " )\n", - " plt.tight_layout()\n", - " plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "93d367cb", - "metadata": {}, - "source": [ - "### Training Loss Curves by Group Size\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5ed8f9c0", - "metadata": {}, - "outputs": [], - "source": [ - "# Plot training loss curves for different group sizes\n", - "# Specify the hidden dimension to use\n", - "hidden_dim = 160 # Change this to plot different hidden dimensions\n", - "\n", - "# Use all available p values (or filter as needed)\n", - "p_values_to_plot = [p for p in p_values if p <= 55] # Adjust max_p as needed\n", - "\n", - "# Create separate plots for each k value\n", - "for k in k_values:\n", - " # Load training loss curves for different p values\n", - " curves = load_training_loss_curves_p(\n", - " sweep_dir,\n", - " k,\n", - " hidden_dim,\n", - " p_values_to_plot,\n", - " )\n", - "\n", - " if not curves:\n", - " print(f\"No data found for k={k}, H={hidden_dim}\")\n", - " continue\n", - "\n", - " # Create plot\n", - " plt.figure(figsize=(10, 8))\n", - "\n", - " # Plot each group size as a separate curve\n", - " # Use a colormap to distinguish different p values\n", - " colors = plt.cm.viridis(np.linspace(0, 1, len(curves)))\n", - "\n", - " for i, (p, loss_histories) in enumerate(sorted(curves.items())):\n", - " # Plot mean curve with shaded error bars\n", - " # Find the maximum length to align all curves\n", - " max_len = max(len(hist) for hist in loss_histories)\n", - "\n", - " # Pad shorter histories with NaN or last value\n", - " aligned_histories = []\n", - " for hist in loss_histories:\n", - " if len(hist) < max_len:\n", - " padded = np.full(max_len, np.nan)\n", - " padded[: len(hist)] = hist\n", - " aligned_histories.append(padded)\n", - " else:\n", - " aligned_histories.append(hist)\n", - "\n", - " aligned_histories = np.array(aligned_histories)\n", - "\n", - " # Compute mean and std across seeds\n", - " mean_loss = np.nanmean(aligned_histories, axis=0)\n", - " std_loss = np.nanstd(aligned_histories, axis=0)\n", - "\n", - " # Create step array (1-indexed for log scale)\n", - " steps = np.arange(1, len(mean_loss) + 1)\n", - "\n", - " # Plot mean curve\n", - " plt.loglog(\n", - " steps,\n", - " mean_loss,\n", - " color=colors[i],\n", - " linewidth=2,\n", - " label=f\"$|G|={p}$\",\n", - " )\n", - "\n", - " # Plot shaded error region (optional, can be commented out if too cluttered)\n", - " # plt.fill_between(\n", - " # steps,\n", - " # mean_loss - std_loss,\n", - " # mean_loss + std_loss,\n", - " # color=colors[i],\n", - " # alpha=0.2,\n", - " # )\n", - "\n", - " plt.xlabel(\"Training Steps\", fontsize=14)\n", - " plt.ylabel(\"Training Loss\", fontsize=14)\n", - " plt.title(\n", - " f\"Training Loss Curves: Group Size $|G|$ vs Steps\\n($k={k}$, $H={hidden_dim}$)\",\n", - " fontsize=14,\n", - " fontweight=\"bold\",\n", - " )\n", - " plt.legend(loc=\"best\", fontsize=10, ncol=2)\n", - " plt.grid(True, alpha=0.3, which=\"both\")\n", - " plt.tight_layout()\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "96be1620", 
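- "metadata": {},
- "outputs": [],
- "source": [
- "# Added sketch (illustrative, not from the sweep code): how the\n",
- "# seed-averaging in the cell above behaves when seeds stop at different\n",
- "# steps. Shorter histories are padded with NaN, so np.nanmean averages\n",
- "# only over the seeds still running at a given step.\n",
- "import numpy as np\n",
- "\n",
- "toy_histories = [np.array([1.0, 0.5, 0.25, 0.125]), np.array([1.0, 0.4])]\n",
- "max_len = max(len(h) for h in toy_histories)\n",
- "aligned = np.full((len(toy_histories), max_len), np.nan)\n",
- "for row, hist in zip(aligned, toy_histories):\n",
- "    row[: len(hist)] = hist\n",
- "\n",
- "print(np.nanmean(aligned, axis=0))  # [1.    0.45  0.25  0.125]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "96be1620b",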
- "metadata": {}, - "outputs": [], - "source": [ - "# Load final validation loss data for each k value separately\n", - "max_p = 60 # Only visualize completed experiments (p <= 55)\n", - "\n", - "from matplotlib.colors import LogNorm\n", - "\n", - "# Create separate plots for each k value\n", - "for k in k_values:\n", - " loss_grid, loss_std = load_sweep_results_grid_final_val_loss_p_h(\n", - " sweep_dir,\n", - " k,\n", - " p_values,\n", - " hidden_dims,\n", - " max_p=max_p,\n", - " )\n", - "\n", - " # Filter p values - only show p <= max_p\n", - " p_values_filtered = [p for p in p_values if p <= max_p]\n", - "\n", - " # Plot final validation loss heatmap: p (group size) vs hidden_dim\n", - " plt.figure(figsize=(12, 8))\n", - " cmap = plt.cm.viridis_r.copy()\n", - " cmap.set_bad(color=\"black\")\n", - " # Set extent to align cells with tick positions\n", - " # extent: [left, right, bottom, top] in data coordinates\n", - " plt.imshow(\n", - " loss_grid[:, : len(p_values_filtered)],\n", - " aspect=\"equal\",\n", - " cmap=cmap,\n", - " norm=LogNorm(),\n", - " )\n", - "\n", - " plt.xlabel(\"Group Size $|G|$\", fontsize=14)\n", - " plt.ylabel(\"Hidden Dimension $H$\", fontsize=14)\n", - " plt.xticks(\n", - " range(len(p_values_filtered)), p_values_filtered, rotation=45, ha=\"center\"\n", - " )\n", - "\n", - " # Set y-axis ticks (hidden dimensions)\n", - " plt.yticks(range(len(hidden_dims)), hidden_dims)\n", - " plt.gca().invert_yaxis()\n", - "\n", - " # Theory boundaries\n", - " x_step = np.arange(len(p_values_filtered) + 1) - 0.5\n", - "\n", - " # Upper boundary: H = (k+1)*2^{k-1} * |G|\n", - " upper_boundary_coeff = (k + 1) * (2 ** (k - 1))\n", - " y_step_upper = [\n", - " min(\n", - " len(hidden_dims) - 1,\n", - " (\n", - " # Find the first H that satisfies H >= upper_boundary_coeff * p\n", - " np.argmax(np.array(hidden_dims) >= upper_boundary_coeff * p)\n", - " if upper_boundary_coeff * p <= max(hidden_dims)\n", - " else len(hidden_dims) - 1\n", - " ),\n", - " )\n", - " for p in p_values_filtered\n", - " ]\n", - " y_step_upper.append(y_step_upper[-1]) # Extend for step plot\n", - " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n", - " y_step_upper = [y - 0.5 for y in y_step_upper]\n", - "\n", - " # Lower boundary: H = 2^{k-1} * |G|\n", - " lower_boundary_coeff = 2 ** (k - 1)\n", - " y_step_lower = [\n", - " min(\n", - " len(hidden_dims) - 1,\n", - " (\n", - " # Find the first H that satisfies H >= lower_boundary_coeff * p\n", - " np.argmax(np.array(hidden_dims) >= lower_boundary_coeff * p)\n", - " if lower_boundary_coeff * p <= max(hidden_dims)\n", - " else len(hidden_dims) - 1\n", - " ),\n", - " )\n", - " for p in p_values_filtered\n", - " ]\n", - " y_step_lower.append(y_step_lower[-1]) # Extend for step plot\n", - " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n", - " y_step_lower = [y - 0.5 for y in y_step_lower]\n", - "\n", - " plt.step(\n", - " x_step,\n", - " y_step_upper,\n", - " where=\"post\",\n", - " color=\"red\",\n", - " linewidth=3,\n", - " linestyle=\"--\",\n", - " label=f\"Upper boundary ($H$ ≥ $(k+1) \\\\cdot 2^{{k-1}} |G|$)\",\n", - " )\n", - "\n", - " plt.step(\n", - " x_step,\n", - " y_step_lower,\n", - " where=\"post\",\n", - " color=\"blue\",\n", - " linewidth=3,\n", - " linestyle=\"--\",\n", - " label=f\"Lower boundary ($H$ ≥ $2^{{k-1}} |G|$)\",\n", - " )\n", - "\n", - " # Place legend outside the plot area\n", - " plt.legend(\n", - " loc=\"upper center\", bbox_to_anchor=(0.5, -0.12), fontsize=12, 
frameon=True\n", - " )\n", - "\n", - " plt.colorbar(label=\"Final Validation Loss\")\n", - " plt.title(\n", - " f\"Final Validation Loss: Group Size $|G|$ vs Hidden Dimension $H$\\n($k={k}$, black = incomplete experiment, p ≤ {max_p})\",\n", - " fontsize=14,\n", - " fontweight=\"bold\",\n", - " )\n", - " plt.tight_layout()\n", - " plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "d3111eeb", - "metadata": {}, - "source": [ - "### Training Instability" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d743a392", - "metadata": {}, - "outputs": [], - "source": [ - "def load_sweep_results_grid_spikiness_p_h(\n", - " sweep_dir: str, k: int, p_values: list, hidden_dims: list, max_p: int = None\n", - "):\n", - " \"\"\"\n", - " Compute fraction of training steps where loss increased (instability) for p vs h sweeps.\n", - "\n", - " Updated for experiment naming: k{k}_p{p}_h{h}\n", - " Only loads completed experiments (checks for run_summary.yaml).\n", - "\n", - " Args:\n", - " sweep_dir: Path to the sweep directory\n", - " k: Sequence length parameter (2, 3, or 4)\n", - " p_values: List of p (group size) values\n", - " hidden_dims: List of hidden dimension values\n", - " max_p: Maximum p value to include (filters incomplete experiments)\n", - "\n", - " Returns:\n", - " grid: 2D array with mean frac_upward across seeds\n", - " Shape: (len(hidden_dims), len(p_values))\n", - " std_grid: 2D array with standard deviations\n", - " \"\"\"\n", - " sweep_path = Path(sweep_dir)\n", - "\n", - " grid = np.full((len(hidden_dims), len(p_values)), np.nan)\n", - " std_grid = np.full((len(hidden_dims), len(p_values)), np.nan)\n", - "\n", - " for i, h in enumerate(hidden_dims):\n", - " for j, p in enumerate(p_values):\n", - " # Filter by max_p if specified\n", - " if max_p is not None and p > max_p:\n", - " continue\n", - "\n", - " exp_name = f\"k{k}_p{p}_h{h}\"\n", - " exp_dir = sweep_path / exp_name\n", - "\n", - " if not exp_dir.exists():\n", - " continue\n", - "\n", - " # Check if experiment is completed\n", - " seed_dir = exp_dir / \"seed_0\"\n", - " if not seed_dir.exists() or not (seed_dir / \"run_summary.yaml\").exists():\n", - " continue # Skip incomplete experiments\n", - "\n", - " frac_upwards = []\n", - " for seed_dir in exp_dir.glob(\"seed_*\"):\n", - " loss_file = seed_dir / \"train_loss_history.npy\"\n", - " if loss_file.exists():\n", - " loss_history = np.load(loss_file)\n", - " log_loss = np.log10(loss_history + 1e-10)\n", - " log_changes = np.diff(log_loss)\n", - "\n", - " # Fraction of steps where loss went UP\n", - " frac_upward = np.sum(log_changes > 0) / len(log_changes)\n", - " frac_upwards.append(frac_upward)\n", - "\n", - " if frac_upwards:\n", - " grid[i, j] = np.mean(frac_upwards)\n", - " std_grid[i, j] = np.std(frac_upwards) if len(frac_upwards) > 1 else 0.0\n", - "\n", - " return grid, std_grid" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "683b555c", - "metadata": {}, - "outputs": [], - "source": [ - "# Load spikiness data for each k value separately\n", - "max_p = 70 # Only visualize completed experiments\n", - "\n", - "# Create separate plots for each k value\n", - "for k in k_values:\n", - " spike_grid_p, spike_std_p = load_sweep_results_grid_spikiness_p_h(\n", - " sweep_dir, k, p_values, hidden_dims, max_p=max_p\n", - " )\n", - "\n", - " p_values_filtered = [p for p in p_values if p <= max_p]\n", - "\n", - " # Plot\n", - " plt.figure(figsize=(12, 8))\n", - " # Set extent to align cells with tick positions\n", - " 
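The metric computed by `load_sweep_results_grid_spikiness_p_h` above is simply the fraction of optimization steps where the $\log_{10}$ loss moves upward; the plotting code below maps it over the $(|G|, H)$ grid. A toy check of the metric on a short, made-up loss history:

```python
import numpy as np

# Hypothetical loss history: three downward moves, two upward spikes.
loss_history = np.array([1.0, 0.5, 0.6, 0.2, 0.25, 0.1])
log_changes = np.diff(np.log10(loss_history + 1e-10))
frac_upward = np.sum(log_changes > 0) / len(log_changes)
print(frac_upward)  # 2 of 5 steps went up -> 0.4
```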
plt.imshow(\n", - " spike_grid_p[:, : len(p_values_filtered)],\n", - " aspect=\"equal\",\n", - " cmap=\"plasma\",\n", - " extent=[-0.5, len(p_values_filtered) - 0.5, len(hidden_dims) - 0.5, -0.5],\n", - " )\n", - " plt.xlabel(\"Group Size $|G|$\", fontsize=14)\n", - " plt.ylabel(\"Hidden Dimension $H$\", fontsize=14)\n", - " plt.xticks(\n", - " range(len(p_values_filtered)), p_values_filtered, rotation=45, ha=\"center\"\n", - " )\n", - " plt.yticks(range(len(hidden_dims)), hidden_dims)\n", - " plt.gca().invert_yaxis()\n", - "\n", - " # Theory boundaries\n", - " x_step = np.arange(len(p_values_filtered) + 1) - 0.5\n", - "\n", - " # Upper boundary: H = (k+1)*2^{k-1} * |G|\n", - " upper_boundary_coeff = (k + 1) * (2 ** (k - 1))\n", - " y_step_upper = [\n", - " min(\n", - " len(hidden_dims) - 1,\n", - " (\n", - " np.argmax(np.array(hidden_dims) >= upper_boundary_coeff * p)\n", - " if upper_boundary_coeff * p <= max(hidden_dims)\n", - " else len(hidden_dims) - 1\n", - " ),\n", - " )\n", - " for p in p_values_filtered\n", - " ]\n", - " y_step_upper.append(y_step_upper[-1])\n", - " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n", - " y_step_upper = [y - 0.5 for y in y_step_upper]\n", - "\n", - " # Lower boundary: H = 2^{k-1} * |G|\n", - " lower_boundary_coeff = 2 ** (k - 1)\n", - " y_step_lower = [\n", - " min(\n", - " len(hidden_dims) - 1,\n", - " (\n", - " np.argmax(np.array(hidden_dims) >= lower_boundary_coeff * p)\n", - " if lower_boundary_coeff * p <= max(hidden_dims)\n", - " else len(hidden_dims) - 1\n", - " ),\n", - " )\n", - " for p in p_values_filtered\n", - " ]\n", - " y_step_lower.append(y_step_lower[-1])\n", - " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n", - " y_step_lower = [y - 0.5 for y in y_step_lower]\n", - "\n", - " plt.step(\n", - " x_step,\n", - " y_step_upper,\n", - " where=\"post\",\n", - " color=\"red\",\n", - " linewidth=3,\n", - " linestyle=\"--\",\n", - " label=f\"Upper boundary ($H$ = $(k+1) \\\\cdot 2^{{k-1}} |G|$)\",\n", - " )\n", - "\n", - " plt.step(\n", - " x_step,\n", - " y_step_lower,\n", - " where=\"post\",\n", - " color=\"white\",\n", - " linewidth=3,\n", - " linestyle=\"--\",\n", - " label=f\"Lower boundary ($H$ = $2^{{k-1}} |G|$)\",\n", - " )\n", - "\n", - " plt.legend(\n", - " loc=\"upper center\", bbox_to_anchor=(0.5, -0.12), fontsize=12, frameon=True\n", - " )\n", - "\n", - " plt.colorbar(label=\"Fraction of Upward Steps (Spikiness)\")\n", - " plt.title(\n", - " f\"Training Instability: Group Size $|G|$ vs Hidden Dimension $H$\\n($k={k}$)\",\n", - " fontsize=14,\n", - " fontweight=\"bold\",\n", - " )\n", - " plt.tight_layout()\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "772517ad", - "metadata": {}, - "outputs": [], - "source": [ - "# Load both metrics for each k value separately\n", - "reduction_threshold = 0.99\n", - "spikiness_threshold = 0.1\n", - "max_p = 55 # Only visualize completed experiments\n", - "\n", - "# Create separate plots for each k value\n", - "for k in k_values:\n", - " conv_grid_p, conv_std_p = load_sweep_results_grid_convergence_p_h(\n", - " sweep_dir, k, p_values, hidden_dims, \n", - " reduction_threshold=reduction_threshold,\n", - " max_p=max_p\n", - " )\n", - " spike_grid_p, spike_std_p = load_sweep_results_grid_spikiness_p_h(\n", - " sweep_dir, k, p_values, hidden_dims, max_p=max_p\n", - " )\n", - " \n", - " p_values_filtered = [p for p in p_values if p <= max_p]\n", - "\n", - " # Create categorical grid: 0=black (no 
conv), 1=purple (spiky), 2=yellow (smooth)\n", - " category_grid = np.full((len(hidden_dims), len(p_values_filtered)), 0.0) # Start with 0 (black)\n", - "\n", - " for i in range(len(hidden_dims)):\n", - " for j in range(len(p_values_filtered)):\n", - " converged = not np.isnan(conv_grid_p[i, j])\n", - "\n", - " if converged:\n", - " spiky = spike_grid_p[i, j] > spikiness_threshold\n", - " if spiky:\n", - " category_grid[i, j] = 1.0 # Purple (spiky)\n", - " else:\n", - " category_grid[i, j] = 2.0 # Yellow (smooth)\n", - " # else stays 0.0 (black, did not converge)\n", - "\n", - " # Plot\n", - " fig, ax = plt.subplots(figsize=(12, 8))\n", - "\n", - " # Custom colormap: black -> purple -> yellow\n", - " from matplotlib.colors import ListedColormap\n", - "\n", - " colors = [\"black\", \"purple\", \"yellow\"]\n", - " cmap = ListedColormap(colors)\n", - "\n", - " # Set extent to align cells with tick positions\n", - " im = ax.imshow(\n", - " category_grid, \n", - " aspect=\"auto\", \n", - " cmap=cmap, \n", - " vmin=0, \n", - " vmax=2,\n", - " extent=[-0.5, len(p_values_filtered) - 0.5, len(hidden_dims) - 0.5, -0.5]\n", - " )\n", - "\n", - " ax.set_xlabel(\"Group Size $|G|$\", fontsize=14)\n", - " ax.set_ylabel(\"Hidden Dimension $H$\", fontsize=14)\n", - "\n", - " # Set x-axis ticks (p values)\n", - " ax.set_xticks(range(len(p_values_filtered)))\n", - " ax.set_xticklabels(p_values_filtered, rotation=45, ha=\"center\")\n", - "\n", - " # Set y-axis ticks (hidden dimensions)\n", - " ax.set_yticks(range(len(hidden_dims)))\n", - " ax.set_yticklabels(hidden_dims)\n", - " ax.invert_yaxis()\n", - "\n", - " # Theory boundaries\n", - " x_step = np.arange(len(p_values_filtered) + 1) - 0.5\n", - " \n", - " # Upper boundary: H = (k+1)*2^{k-1} * |G|\n", - " upper_boundary_coeff = (k + 1) * (2 ** (k - 1))\n", - " y_step_upper = [\n", - " min(\n", - " len(hidden_dims) - 1,\n", - " (\n", - " np.argmax(np.array(hidden_dims) >= upper_boundary_coeff * p)\n", - " if upper_boundary_coeff * p <= max(hidden_dims)\n", - " else len(hidden_dims) - 1\n", - " ),\n", - " )\n", - " for p in p_values_filtered\n", - " ]\n", - " y_step_upper.append(y_step_upper[-1])\n", - " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n", - " y_step_upper = [y - 0.5 for y in y_step_upper]\n", - "\n", - " # Lower boundary: H = 2^{k-1} * |G|\n", - " lower_boundary_coeff = 2 ** (k - 1)\n", - " y_step_lower = [\n", - " min(\n", - " len(hidden_dims) - 1,\n", - " (\n", - " np.argmax(np.array(hidden_dims) >= lower_boundary_coeff * p)\n", - " if lower_boundary_coeff * p <= max(hidden_dims)\n", - " else len(hidden_dims) - 1\n", - " ),\n", - " )\n", - " for p in p_values_filtered\n", - " ]\n", - " y_step_lower.append(y_step_lower[-1])\n", - " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n", - " y_step_lower = [y - 0.5 for y in y_step_lower]\n", - "\n", - " ax.step(\n", - " x_step,\n", - " y_step_upper,\n", - " where=\"post\",\n", - " color=\"red\",\n", - " linewidth=3,\n", - " linestyle=\"--\",\n", - " label=f\"Upper boundary ($H$ ≥ $(k+1) \\\\cdot 2^{{k-1}} |G|$)\",\n", - " )\n", - " \n", - " ax.step(\n", - " x_step,\n", - " y_step_lower,\n", - " where=\"post\",\n", - " color=\"blue\",\n", - " linewidth=3,\n", - " linestyle=\"--\",\n", - " label=f\"Lower boundary ($H$ ≥ $2^{{k-1}} |G|$)\",\n", - " )\n", - "\n", - " # Create custom legend\n", - " from matplotlib.patches import Patch\n", - "\n", - " legend_elements = [\n", - " Patch(facecolor=\"black\", label=\"Did not converge\"),\n", - " 
Patch(facecolor=\"purple\", label=f\"Spiky (frac_up > {spikiness_threshold})\"),\n", - " Patch(facecolor=\"yellow\", label=\"Smooth convergence\"),\n", - " plt.Line2D([0], [0], color=\"r\", linewidth=3, linestyle=\"--\", label=f\"Upper boundary ($H$ ≥ $(k+1) \\\\cdot 2^{{k-1}} |G|$)\"),\n", - " plt.Line2D([0], [0], color=\"b\", linewidth=3, linestyle=\"--\", label=f\"Lower boundary ($H$ ≥ $2^{{k-1}} |G|$)\"),\n", - " ]\n", - "\n", - " ax.legend(handles=legend_elements, loc=\"upper left\", fontsize=11, frameon=True)\n", - "\n", - " ax.set_title(\n", - " f\"Convergence & Spikiness: $|G|$ vs $H$ ($k={k}$)\\nThresholds: {reduction_threshold*100}% convergence, {spikiness_threshold} spikiness (p ≤ {max_p})\",\n", - " fontsize=14,\n", - " fontweight=\"bold\",\n", - " )\n", - " plt.tight_layout()\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "53ed6f4f", - "metadata": {}, - "outputs": [], - "source": [ - "# Load both convergence and spikiness data for each k value separately\n", - "reduction_threshold = 0.99\n", - "max_p = 55 # Only visualize completed experiments\n", - "\n", - "# Create separate plots for each k value\n", - "for k in k_values:\n", - " conv_grid_p, conv_std_p = load_sweep_results_grid_convergence_p_h(\n", - " sweep_dir,\n", - " k,\n", - " p_values,\n", - " hidden_dims,\n", - " reduction_threshold=reduction_threshold,\n", - " max_p=max_p,\n", - " )\n", - " spike_grid_p, spike_std_p = load_sweep_results_grid_spikiness_p_h(\n", - " sweep_dir, k, p_values, hidden_dims, max_p=max_p\n", - " )\n", - "\n", - " p_values_filtered = [p for p in p_values if p <= max_p]\n", - "\n", - " # Mask spikiness grid: only show spikiness for converged runs\n", - " spike_grid_masked = spike_grid_p.copy()\n", - " for i in range(len(hidden_dims)):\n", - " for j in range(len(p_values_filtered)):\n", - " if np.isnan(conv_grid_p[i, j]):\n", - " # Did not converge - set to NaN (will be black)\n", - " spike_grid_masked[i, j] = np.nan\n", - "\n", - " # Plot with masked spikiness\n", - " plt.figure(figsize=(12, 8))\n", - "\n", - " # Use colormap with black for NaN\n", - " cmap_spike = plt.cm.plasma.copy()\n", - " cmap_spike.set_bad(color=\"black\")\n", - "\n", - " # Set extent to align cells with tick positions\n", - " plt.imshow(\n", - " spike_grid_masked[:, : len(p_values_filtered)],\n", - " aspect=\"auto\",\n", - " cmap=cmap_spike,\n", - " vmin=0,\n", - " vmax=0.5,\n", - " extent=[-0.5, len(p_values_filtered) - 0.5, len(hidden_dims) - 0.5, -0.5],\n", - " )\n", - " plt.xlabel(\"Group Size $|G|$\", fontsize=14)\n", - " plt.ylabel(\"Hidden Dimension $H$\", fontsize=14)\n", - " plt.xticks(\n", - " range(len(p_values_filtered)), p_values_filtered, rotation=45, ha=\"center\"\n", - " )\n", - " plt.yticks(range(len(hidden_dims)), hidden_dims)\n", - " plt.gca().invert_yaxis()\n", - "\n", - " # Theory boundaries\n", - " x_step = np.arange(len(p_values_filtered) + 1) - 0.5\n", - "\n", - " # Upper boundary: H = (k+1)*2^{k-1} * |G|\n", - " upper_boundary_coeff = (k + 1) * (2 ** (k - 1))\n", - " y_step_upper = [\n", - " min(\n", - " len(hidden_dims) - 1,\n", - " (\n", - " np.argmax(np.array(hidden_dims) >= upper_boundary_coeff * p)\n", - " if upper_boundary_coeff * p <= max(hidden_dims)\n", - " else len(hidden_dims) - 1\n", - " ),\n", - " )\n", - " for p in p_values_filtered\n", - " ]\n", - " y_step_upper.append(y_step_upper[-1])\n", - " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n", - " y_step_upper = [y - 0.5 for y in y_step_upper]\n", - "\n", - 
" # Lower boundary: H = 2^{k-1} * |G|\n", - " lower_boundary_coeff = 2 ** (k - 1)\n", - " y_step_lower = [\n", - " min(\n", - " len(hidden_dims) - 1,\n", - " (\n", - " np.argmax(np.array(hidden_dims) >= lower_boundary_coeff * p)\n", - " if lower_boundary_coeff * p <= max(hidden_dims)\n", - " else len(hidden_dims) - 1\n", - " ),\n", - " )\n", - " for p in p_values_filtered\n", - " ]\n", - " y_step_lower.append(y_step_lower[-1])\n", - " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n", - " y_step_lower = [y - 0.5 for y in y_step_lower]\n", - "\n", - " plt.step(\n", - " x_step,\n", - " y_step_upper,\n", - " where=\"post\",\n", - " color=\"red\",\n", - " linewidth=3,\n", - " linestyle=\"--\",\n", - " label=f\"Upper boundary ($H$ ≥ $(k+1) \\\\cdot 2^{{k-1}} |G|$)\",\n", - " )\n", - "\n", - " plt.step(\n", - " x_step,\n", - " y_step_lower,\n", - " where=\"post\",\n", - " color=\"blue\",\n", - " linewidth=3,\n", - " linestyle=\"--\",\n", - " label=f\"Lower boundary ($H$ ≥ $2^{{k-1}} |G|$)\",\n", - " )\n", - "\n", - " # Custom legend\n", - " from matplotlib.patches import Patch\n", - "\n", - " legend_elements = [\n", - " Patch(facecolor=\"black\", label=\"Did not converge\"),\n", - " Patch(facecolor=\"purple\", label=\"Low spikiness (~0)\"),\n", - " Patch(facecolor=\"yellow\", label=\"High spikiness (~0.5)\"),\n", - " plt.Line2D(\n", - " [0],\n", - " [0],\n", - " color=\"red\",\n", - " linewidth=3,\n", - " linestyle=\"--\",\n", - " label=f\"Upper boundary ($H$ ≥ $(k+1) \\\\cdot 2^{{k-1}} |G|$)\",\n", - " ),\n", - " plt.Line2D(\n", - " [0],\n", - " [0],\n", - " color=\"white\",\n", - " linewidth=3,\n", - " linestyle=\"--\",\n", - " label=f\"Lower boundary ($H$ ≥ $2^{{k-1}} |G|$)\",\n", - " ),\n", - " ]\n", - " plt.legend(\n", - " handles=legend_elements,\n", - " loc=\"upper center\",\n", - " bbox_to_anchor=(0.5, -0.12),\n", - " fontsize=11,\n", - " frameon=True,\n", - " ncol=5,\n", - " )\n", - "\n", - " plt.colorbar(label=\"Fraction of Upward Steps (Spikiness)\")\n", - " plt.title(\n", - " f\"Training Instability: Group Size $|G|$ vs Hidden Dimension $H$\\n($k={k}$, black = did not converge, p ≤ {max_p})\",\n", - " fontsize=14,\n", - " fontweight=\"bold\",\n", - " )\n", - " plt.tight_layout()\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3f30eb24", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "group-agf", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/sequential_cnxcn.ipynb b/notebooks/sequential_cnxcn.ipynb new file mode 100644 index 0000000..f957123 --- /dev/null +++ b/notebooks/sequential_cnxcn.ipynb @@ -0,0 +1,287 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Sequential Group Composition on $C_n \\times C_n$\n", + "\n", + "**Group:** Product of cyclic groups $C_n \\times C_n$ of order $n^2$. \n", + "**Task:** Given a sequence of $k$ group elements $g_1, \\ldots, g_k \\in C_n \\times C_n$, predict their cumulative product. \n", + "**Sequence length:** $k = 3$ (sequential composition). \n", + "**Architecture:** `QuadraticRNN` with quadratic recurrence. 
\n", + "**Key result:** The RNN composes elements sequentially in $k$ steps, exploiting associativity." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import random\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import DataLoader, TensorDataset\n", + "\n", + "import src.dataset as dataset\n", + "import src.model as model\n", + "import src.optimizer as optimizer\n", + "import src.power as power\n", + "import src.template as template\n", + "import src.train as train_mod\n", + "import src.viz as viz" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n", + "\n", + "seed = 0\n", + "random.seed(seed)\n", + "np.random.seed(seed)\n", + "torch.manual_seed(seed)\n", + "\n", + "p1, p2 = 7, 7 # Cn x Cn dimensions\n", + "p_flat = p1 * p2\n", + "k = 3 # Sequence length\n", + "hidden_dim = 50 if TEST_MODE else 200\n", + "epochs = 2 if TEST_MODE else 5\n", + "num_samples = 100 if TEST_MODE else 10000\n", + "batch_size = 64 if TEST_MODE else 1000\n", + "lr = 1e-3\n", + "init_scale = 1e-2\n", + "mode = \"sampled\"\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "FIGURES_DIR = \"figures\"\n", + "os.makedirs(FIGURES_DIR, exist_ok=True)\n", + "\n", + "print(f\"Group: C_{p1} x C_{p2}, order {p_flat}\")\n", + "print(f\"Sequence length: k={k}\")\n", + "print(f\"Device: {device}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Template and Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Build a 2D template with known Fourier structure\n", + "tpl_2d = template.unique_freqs_2d(p1, p2, n_freqs=5, seed=seed)\n", + "\n", + "# Build sequential dataset\n", + "X, Y, sequence_xy = dataset.build_modular_addition_sequence_dataset_2d(\n", + " p1, p2, tpl_2d, k, mode=mode, num_samples=num_samples,\n", + ")\n", + "\n", + "X_tensor = torch.tensor(X, dtype=torch.float32).to(device)\n", + "Y_tensor = torch.tensor(Y, dtype=torch.float32).to(device)\n", + "\n", + "ds = TensorDataset(X_tensor, Y_tensor)\n", + "dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True)\n", + "\n", + "print(f\"Dataset: {len(ds)} samples\")\n", + "print(f\"X shape: {X_tensor.shape} (N, k, p1*p2)\")\n", + "print(f\"Y shape: {Y_tensor.shape} (N, p1*p2)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize template\n", + "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))\n", + "\n", + "im = ax1.imshow(tpl_2d, cmap=\"RdBu_r\")\n", + "ax1.set_title(f\"Template on $C_{{{p1}}} \\\\times C_{{{p2}}}$\")\n", + "ax1.set_xlabel(\"$C_{\" + str(p2) + \"}$ index\")\n", + "ax1.set_ylabel(\"$C_{\" + str(p1) + \"}$ index\")\n", + "plt.colorbar(im, ax=ax1)\n", + "\n", + "# Show 2D power spectrum\n", + "pwr_2d = power.get_power_2d(tpl_2d, no_freq=True)\n", + "im2 = ax2.imshow(np.log10(pwr_2d + 1e-12), cmap=\"hot\")\n", + "ax2.set_title(\"Log power spectrum\")\n", + "ax2.set_xlabel(\"Freq $k_2$\")\n", + "ax2.set_ylabel(\"Freq 
$k_1$\")\n", + "plt.colorbar(im2, ax=ax2)\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(f\"{FIGURES_DIR}/sequential_cnxcn_template.pdf\", bbox_inches=\"tight\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model and Optimizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "net = model.QuadraticRNN(\n", + " p=p_flat,\n", + " d=hidden_dim,\n", + " template=tpl_2d.flatten(),\n", + " init_scale=init_scale,\n", + ")\n", + "net = net.to(device)\n", + "\n", + "criterion = nn.MSELoss()\n", + "opt = optimizer.HybridRNNOptimizer(net, lr=lr)\n", + "\n", + "print(f\"Model: QuadraticRNN(p={p_flat}, d={hidden_dim}, k={k})\")\n", + "print(f\"Optimizer: HybridRNNOptimizer(lr={lr})\")\n", + "print(f\"Training for {epochs} epochs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loss_history, val_loss_history, param_history, param_save_epochs, final_epoch = train_mod.train(\n", + " net,\n", + " dataloader,\n", + " criterion,\n", + " opt,\n", + " epochs=epochs,\n", + " verbose_interval=1,\n", + " save_param_interval=1,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training Loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "\n", + "# Plot training loss with theoretical levels\n", + "ax.plot(loss_history, lw=4, label=\"Train loss\")\n", + "\n", + "theory = power.theoretical_loss_levels_2d(tpl_2d)\n", + "for level in theory[\"levels\"]:\n", + " ax.axhline(y=level, color=\"black\", linestyle=\"--\", linewidth=1.5, zorder=-2)\n", + "\n", + "ax.set_xlabel(\"Epochs\", fontsize=18)\n", + "ax.set_ylabel(\"Train Loss\", fontsize=18)\n", + "ax.set_title(f\"Training loss (sequential $k={k}$ on $C_{{{p1}}} \\\\times C_{{{p2}}}$)\", fontsize=18)\n", + "viz.style_axes(ax)\n", + "ax.grid(False)\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(f\"{FIGURES_DIR}/sequential_cnxcn_loss.pdf\", bbox_inches=\"tight\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Show predictions vs ground truth\n", + "net.load_state_dict(param_history[-1])\n", + "net.eval()\n", + "\n", + "n_examples = 3\n", + "indices = np.random.choice(len(Y_tensor), size=n_examples, replace=False)\n", + "\n", + "fig, axes = plt.subplots(n_examples, 2, figsize=(8, 3 * n_examples))\n", + "\n", + "with torch.no_grad():\n", + " x_batch = X_tensor[indices]\n", + " preds = net(x_batch).detach().cpu().numpy()\n", + " truths = Y_tensor[indices].detach().cpu().numpy()\n", + "\n", + "for i in range(n_examples):\n", + " axes[i, 0].imshow(truths[i].reshape(p1, p2), cmap=\"RdBu_r\")\n", + " axes[i, 0].set_title(\"Ground truth\")\n", + " axes[i, 1].imshow(preds[i].reshape(p1, p2), cmap=\"RdBu_r\")\n", + " axes[i, 1].set_title(\"Prediction\")\n", + "\n", + "plt.suptitle(f\"Predictions (sequential $k={k}$ on $C_{{{p1}}} \\\\times C_{{{p2}}}$, epoch {final_epoch})\", fontsize=16)\n", + "plt.tight_layout()\n", + "plt.savefig(f\"{FIGURES_DIR}/sequential_cnxcn_predictions.pdf\", bbox_inches=\"tight\")\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { 
+ "display_name": "group-agf", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/template_bar.pdf b/notebooks/template_bar.pdf deleted file mode 100644 index 63f21bf..0000000 Binary files a/notebooks/template_bar.pdf and /dev/null differ diff --git a/notebooks/template_fft_bar.pdf b/notebooks/template_fft_bar.pdf deleted file mode 100644 index 808eaf6..0000000 Binary files a/notebooks/template_fft_bar.pdf and /dev/null differ diff --git a/notebooks/znz_znz.ipynb b/notebooks/znz_znz.ipynb deleted file mode 100644 index 9e38ce6..0000000 --- a/notebooks/znz_znz.ipynb +++ /dev/null @@ -1,234 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "51d11caf-0971-4324-b63b-819b714a9c3c", - "metadata": {}, - "source": [ - "# Learning Z/nZ x Z/nZ group actions\n", - "This notebook is adapted from the `modular arithmetic` notebook, replacing `Z/nZ` group action with `Z/nZ x Z/nZ` group action " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "80f249f1-6985-4c73-86cd-04e1adac3e8d", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import random\n", - "import torch\n", - "import os\n", - "import torch.nn as nn\n", - "import torch.optim as optim\n", - "import shutil\n", - "from torch.utils.data import DataLoader, TensorDataset\n", - "import matplotlib.pyplot as plt\n", - "import matplotlib.cm as cm\n", - "from matplotlib.animation import FuncAnimation\n", - "from matplotlib.ticker import FormatStrFormatter\n", - "from matplotlib.ticker import FuncFormatter\n", - "from matplotlib.ticker import MaxNLocator\n", - "\n", - "import importlib\n", - "import pickle\n", - "\n", - "import group_agf.binary_action_learning.models as models\n", - "import group_agf.binary_action_learning.datasets as datasets\n", - "import group_agf.binary_action_learning.power as power\n", - "import group_agf.binary_action_learning.train as train\n", - "import group_agf.binary_action_learning.plot as plot\n" - ] - }, - { - "cell_type": "markdown", - "id": "da0f43df", - "metadata": {}, - "source": [ - "# Define Dataset and Visualize" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f5dba48b", - "metadata": {}, - "outputs": [], - "source": [ - "from group_agf.binary_action_learning.default_config import verbose_interval\n", - "import os\n", - "\n", - "# TEST_MODE: Set to reduce epochs for automated testing\n", - "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n", - "\n", - "p = 3 if TEST_MODE else 5 # Reduced in test mode\n", - "mnist_digit = 4\n", - "dataset_fraction = 0.1 if TEST_MODE else 0.2 # Reduced in test mode\n", - "template_type = 'mnist'\n", - "seed = 47\n", - "batch_size = 32 if TEST_MODE else 128 # Reduced in test mode\n", - "hidden_size = 32 if TEST_MODE else 128 # Reduced in test mode\n", - "lr = 0.001\n", - "mom = 0.9\n", - "init_scale = 1e-2\n", - "epochs = 2 if TEST_MODE else 1000\n", - "verbose_interval = max(1, epochs // 10)\n", - "\n", - "model_save_path = (\n", - " f\"/tmp/adele/model_\"\n", - " f\"p{p}_\"\n", - " f\"digit{mnist_digit}_\"\n", - " f\"frac{dataset_fraction}_\"\n", - " f\"type{template_type}_\"\n", - " f\"seed{seed}.pkl\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "615fe334", - "metadata": {}, - "outputs": [], - "source": [ - "template = datasets.choose_template(p, template_type, mnist_digit)\n", - "group = 'cnxcn'\n", - "\n", - 
"top_frequency_plot = plot.plot_top_template_components(group, template, p)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a28838b7", - "metadata": {}, - "outputs": [], - "source": [ - "X, Y, translations = datasets.load_modular_addition_dataset_2d(p, template, fraction=dataset_fraction, random_state=seed, template_type=template_type)\n", - "\n", - "X, Y, device = datasets.move_dataset_to_device_and_flatten(X, Y, p, device=None)\n", - "\n", - "dataset = TensorDataset(X, Y)\n", - "dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)" - ] - }, - { - "cell_type": "markdown", - "id": "111a2530", - "metadata": {}, - "source": [ - "# Define Model and Train" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "00b56ec1", - "metadata": {}, - "outputs": [], - "source": [ - "np.random.seed(seed)\n", - "torch.manual_seed(seed)\n", - "torch.cuda.manual_seed_all(seed) # if using GPU\n", - "\n", - "model = models.TwoLayerNet(p=p, hidden_size=hidden_size, nonlinearity='square', init_scale=init_scale, output_scale=1e0)\n", - "model = model.to(device)\n", - "loss = nn.MSELoss()\n", - "optimizer = optim.Adam(model.parameters(), lr=lr, betas=(mom, 0.999))\n", - "\n", - "loss_history, accuracy_history, param_history = train.train(\n", - " model,\n", - " dataloader,\n", - " loss,\n", - " optimizer,\n", - " epochs=epochs,\n", - " verbose_interval=verbose_interval,\n", - " model_save_path=model_save_path\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "e3b19dcc", - "metadata": {}, - "source": [ - "# Plot loss, power, and model output" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5989bcd9", - "metadata": {}, - "outputs": [], - "source": [ - "loss_plot = plot.plot_loss_curve(loss_history, template)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b1965705", - "metadata": {}, - "outputs": [], - "source": [ - "template_2d = template.reshape((p, p))\n", - "power_over_training_plot = plot.plot_training_power_over_time(template_2d, model, device, param_history, X, p, save_path=None, show=False) " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d0319335", - "metadata": {}, - "outputs": [], - "source": [ - "neuron_indices = list(range(20))\n", - "group= 'cnxcn'\n", - "print(neuron_indices)\n", - "neuron_weights_plot = plot.plot_neuron_weights(group, model, p, neuron_indices=neuron_indices, show=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "80fc35d9", - "metadata": {}, - "outputs": [], - "source": [ - "idx = 13\n", - "plot.plot_model_outputs(p, model, X, Y, idx)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "296b3d56", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "group-agf", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/test/test_notebooks.py b/test/test_notebooks.py index a87daf6..9352ebb 100644 --- a/test/test_notebooks.py +++ b/test/test_notebooks.py @@ -22,6 +22,7 @@ import os import subprocess import sys +import tempfile from pathlib import Path import pytest @@ -39,19 +40,8 @@ def get_notebooks_dir(): # Notebooks to skip (with 
reasons) SKIP_NOTEBOOKS = { - # These notebooks have hardcoded paths to /home/facosta/ which don't exist - "seq_mlp_group_size": "Has hardcoded paths to /home/facosta/ filesystem", - "rnn_gagf": "Has hardcoded paths to /home/facosta/ filesystem", - # These notebooks require pre-trained model files or external data - "paper_figures": "Requires pre-trained model .pkl files not included in repo", - # These notebooks have import/code issues that need separate debugging - "2D": "Missing function: cannot import 'get_power_2d' from src.power", - "znz_znz": "Missing function: datasets.choose_template() does not exist", - "seq_mlp": "Plotting error: Invalid vmin/vmax values during visualization", - # These notebooks have visualization code with hardcoded indices that fail with reduced p - "C_n": "IndexError in visualization code when running with reduced parameters", - "dihedral": "IndexError in visualization code when running with reduced parameters", - "modular_arithmetic": "IndexError in visualization code when running with reduced parameters", + # Add notebooks here if they need to be skipped, e.g.: + # "notebook_name": "Reason for skipping", } @@ -89,32 +79,34 @@ def execute_notebook(notebook_path, env): tuple: (success: bool, error_message: str or None) """ try: - result = subprocess.run( - [ - sys.executable, - "-m", - "jupyter", - "nbconvert", - "--to", - "notebook", - "--execute", - "--ExecutePreprocessor.timeout=300", # 5 minute timeout per notebook - "--ExecutePreprocessor.kernel_name=python3", - "--output", - "/dev/null", - str(notebook_path), - ], - capture_output=True, - text=True, - env=env, - cwd=str(get_repo_root()), - timeout=360, # 6 minute overall timeout - ) - - if result.returncode != 0: - error_msg = f"STDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}" - return False, error_msg - return True, None + with tempfile.TemporaryDirectory() as tmpdir: + output_path = os.path.join(tmpdir, "output.ipynb") + result = subprocess.run( + [ + sys.executable, + "-m", + "jupyter", + "nbconvert", + "--to", + "notebook", + "--execute", + "--ExecutePreprocessor.timeout=300", # 5 minute timeout per notebook + "--ExecutePreprocessor.kernel_name=python3", + "--output", + output_path, + str(notebook_path), + ], + capture_output=True, + text=True, + env=env, + cwd=str(get_repo_root()), + timeout=360, # 6 minute overall timeout + ) + + if result.returncode != 0: + error_msg = f"STDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}" + return False, error_msg + return True, None except subprocess.TimeoutExpired: return False, "Notebook execution timed out (>6 minutes)"
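For reference, a standalone sketch of the fixed invocation outside pytest. The notebook path here is illustrative, and `NOTEBOOK_TEST_MODE` is assumed to be the same flag the notebooks read to shrink their workloads; writing the executed copy into a temporary directory sidesteps the old `--output /dev/null` target, which nbconvert does not handle as a plain sink (it derives an `.ipynb` output name from it).

```python
import os
import subprocess
import sys
import tempfile

# Hypothetical single-notebook run mirroring the fixed execute_notebook logic.
notebook = "notebooks/sequential_cnxcn.ipynb"
env = {**os.environ, "NOTEBOOK_TEST_MODE": "1"}

with tempfile.TemporaryDirectory() as tmpdir:
    result = subprocess.run(
        [
            sys.executable, "-m", "jupyter", "nbconvert",
            "--to", "notebook", "--execute",
            "--output", os.path.join(tmpdir, "output.ipynb"),
            notebook,
        ],
        capture_output=True, text=True, env=env, timeout=360,
    )

if result.returncode != 0:
    print(result.stderr)  # surface the traceback from the failing cell
```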