diff --git a/.gitignore b/.gitignore
index 509187e..791ddd9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@ training_history.pkl
fourier_power_only.pdf
notebooks/figs/
notebooks/figs_temp/
+notebooks/figures/
notebooks/saved_models/
notebooks/checkpoint*
notebooks/W-weights.pdf
diff --git a/notebooks/2D.ipynb b/notebooks/2D.ipynb
deleted file mode 100644
index eac8f90..0000000
--- a/notebooks/2D.ipynb
+++ /dev/null
@@ -1,1701 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 2D Sequential modular addition with a Quadratic RNN\n",
- "\n",
- "To do:\n",
- "- Should have an option to compute loss over all intermediate compositions (this will allow us to train for longer)\n",
- "- Should create a dataloader that samples online\n",
- "- Better activation plots\n"
- ],
- "id": "14ad894a-55e4-4f86-9e22-b5260504966b"
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Set up"
- ],
- "id": "6d1d8a65"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "# autoreload\n",
- "%load_ext autoreload\n",
- "%autoreload 2\n",
- "# jupyter black formatter\n",
- "%load_ext jupyter_black\n",
- "\n",
- "import subprocess\n",
- "import os\n",
- "import sys\n",
- "\n",
- "gitroot_path = subprocess.check_output(\n",
- " [\"git\", \"rev-parse\", \"--show-toplevel\"], universal_newlines=True\n",
- ")\n",
- "\n",
- "os.chdir(os.path.join(gitroot_path[:-1], \"gagf\"))\n",
- "print(\"Working directory: \", os.getcwd())\n",
- "\n",
- "sys_dir = os.path.dirname(os.getcwd())\n",
- "sys.path.append(sys_dir)\n",
- "print(\"Directory added to path: \", sys_dir)\n",
- "sys.path.append(os.getcwd())\n",
- "print(\"Directory added to path: \", os.getcwd())"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "fd71a9a8"
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Imports"
- ],
- "id": "62884979"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "# Core\n",
- "import numpy as np\n",
- "import torch\n",
- "import torch.nn as nn\n",
- "\n",
- "# Torch utilities\n",
- "import torch.optim as optim\n",
- "from torch.utils.data import DataLoader, TensorDataset\n",
- "\n",
- "# Vision\n",
- "import torchvision\n",
- "from torchvision import transforms\n",
- "\n",
- "# Plotting\n",
- "import matplotlib.pyplot as plt\n",
- "import matplotlib.gridspec as gridspec\n",
- "from matplotlib.patches import Rectangle, Patch\n",
- "from matplotlib.ticker import MaxNLocator\n",
- "from matplotlib.lines import Line2D\n",
- "from matplotlib.colors import Normalize\n",
- "from matplotlib.colors import LogNorm\n",
- "import matplotlib.cm as cm\n",
- "\n",
- "# Misc\n",
- "from tqdm import tqdm\n",
- "from typing import Optional"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "80f249f1-6985-4c73-86cd-04e1adac3e8d"
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## RNN Architecture"
- ],
- "id": "7a0ecbbd-ceaf-4bef-af4a-13a22fa70063"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "from gagf.rnns.model import QuadraticRNN"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "5f63c4dd"
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Optimization"
- ],
- "id": "f7e7336b-5c6e-48af-a357-2b2c877f6168"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "from gagf.rnns.train import train"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "1035f81c-e877-4655-8640-4e4c3d323af8"
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Plotting functions"
- ],
- "id": "0e86c4f6-83a6-4465-abf0-7d104432cc9c"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "from gagf.rnns.utils import (\n",
- " style_axes,\n",
- " get_power_2d_adele,\n",
- " topk_template_freqs,\n",
- ")"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "014e2d10-9550-4fd4-adb7-168a27fda1b3"
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 2D Dataset"
- ],
- "id": "54bef8ae"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "from gagf.rnns.datamodule import (\n",
- " build_modular_addition_sequence_dataset_2d,\n",
- " mnist_template_2d,\n",
- " generate_template_unique_freqs,\n",
- ")"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "2aa4c3fd"
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Experiment Setup"
- ],
- "id": "f1c0453e"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "from gagf.rnns.utils import plot_2d_power_spectrum, plot_2d_signal, get_power_2d\n",
- "\n",
- "# Set seed\n",
- "np.random.seed(5)\n",
- "torch.manual_seed(5)\n",
- "\n",
- "# 2D dimensions\n",
- "p1, p2 = 28, 28 # Can be different, but start with square\n",
- "p_flat = p1 * p2\n",
- "\n",
- "# Generate 2D template\n",
- "template_2d = generate_template_unique_freqs(p1, p2, n_freqs=10)\n",
- "\n",
- "# Mean center template\n",
- "template_2d = template_2d - np.mean(template_2d)\n",
- "\n",
- "# Visualize template and its spectrum\n",
- "fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n",
- "plot_2d_signal(axes[0], template_2d, title=\"2D Template\", cmap=\"grey\")\n",
- "power_2d, fx, fy = get_power_2d(template_2d)\n",
- "plot_2d_power_spectrum(axes[1], power_2d, fx, fy, title=\"Template Power Spectrum\")\n",
- "plt.tight_layout()\n",
- "plt.show()"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "d424edf7-ad00-4836-a57c-6de2b2e06a63"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "# Build sequence data\n",
- "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
- "k = 3 # sequence length\n",
- "mode = \"sampled\"\n",
- "# TEST_MODE: Reduce num_samples for faster automated testing\n",
- "import os\n",
- "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n",
- "num_samples = 1000 if TEST_MODE else 100000 # Reduced in test mode\n",
- "\n",
- "X_seq_2d, Y_seq_2d, sequence_xy = build_modular_addition_sequence_dataset_2d(\n",
- " p1, p2, template_2d, k, mode=mode, num_samples=num_samples\n",
- ")\n",
- "\n",
- "# Convert to torch tensors\n",
- "X_seq_2d_t = torch.tensor(X_seq_2d, dtype=torch.float32, device=device)\n",
- "Y_seq_2d_t = torch.tensor(Y_seq_2d, dtype=torch.float32, device=device)\n",
- "\n",
- "print(f\"Dataset shapes:\")\n",
- "print(f\" X: {X_seq_2d_t.shape} (N, k, p1*p2)\")\n",
- "print(f\" Y: {Y_seq_2d_t.shape} (N, p1*p2)\")\n",
- "print(f\" Flattened dimension: {p_flat}\")"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "1a643405"
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Train 2D model"
- ],
- "id": "624d77d9"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "from gagf.rnns.main import main"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "4b53c4ac"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "config = {\n",
- " \"data\": {\n",
- " \"p1\": p1,\n",
- " \"p2\": p2,\n",
- " \"k\": k,\n",
- " \"mode\": mode,\n",
- " \"num_samples\": num_samples,\n",
- " \"n_freqs\": 10,\n",
- " \"batch_size\": 1000,\n",
- " \"seed\": 5,\n",
- " },\n",
- " \"model\": {\n",
- " \"hidden_dim\": 36,\n",
- " \"init_scale\": 1.0e-5,\n",
- " },\n",
- " \"training\": {\n",
- " \"epochs\": 5,\n",
- " \"learning_rate\": 0.0001,\n",
- " \"weight_decay\": 0.0,\n",
- " \"betas\": [0.9, 0.999],\n",
- " \"grad_clip\": 0.1,\n",
- " \"verbose_interval\": 5,\n",
- " },\n",
- " \"device\": \"cuda\",\n",
- "}"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "474b1423"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "main(config)"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "38b78194"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "from gagf.rnns.train import train\n",
- "\n",
- "# TEST_MODE: Reduce batch_size and hidden_dim for faster automated testing\n",
- "import os\n",
- "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n",
- "batch_size = 100 if TEST_MODE else 1000 # Reduced in test mode\n",
- "seq_dataset_2d = TensorDataset(X_seq_2d_t, Y_seq_2d_t)\n",
- "seq_loader_2d = DataLoader(\n",
- " seq_dataset_2d, batch_size=batch_size, shuffle=True, drop_last=False\n",
- ")\n",
- "\n",
- "# Model - note p is now p1*p2\n",
- "template_torch = torch.tensor(template_2d, device=device, dtype=torch.float32).flatten()\n",
- "hidden_dim_2d = 12 if TEST_MODE else 36 # Reduced in test mode\n",
- "rnn_2d = QuadraticRNN(\n",
- " p=p_flat,\n",
- " d=hidden_dim_2d,\n",
- " template=template_torch,\n",
- " init_scale=1e-5,\n",
- ").to(device)\n",
- "\n",
- "criterion = nn.MSELoss()\n",
- "\n",
- "weight_decay = 0 # relevant for W_mix plot\n",
- "learning_rate = 0.0001\n",
- "\n",
- "optimizer = optim.Adam(\n",
- " rnn_2d.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=weight_decay\n",
- ")\n",
- "\n",
- "# TEST_MODE: Set to reduce epochs for automated testing\n",
- "import os\n",
- "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n",
- "epochs = 2 if TEST_MODE else 5\n",
- "\n",
- "loss_hist_2d, acc_hist_2d, param_hist_2d = train(\n",
- " rnn_2d,\n",
- " seq_loader_2d,\n",
- " criterion,\n",
- " optimizer,\n",
- " epochs=epochs,\n",
- " verbose_interval=max(1, epochs),\n",
- " grad_clip=0.1,\n",
- ")"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "31005b99"
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Loss Plot"
- ],
- "id": "aaf7af30-5f5e-4ea5-8425-1da1efc1fa69"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "# --- Plot RNN training loss only ---\n",
- "fig, ax = plt.subplots(1, 1, figsize=(10, 6))\n",
- "\n",
- "# x-axis = epochs 1..T\n",
- "x = np.arange(1, len(loss_hist_2d) + 1)\n",
- "ax.plot(x, loss_hist_2d, lw=4)\n",
- "\n",
- "# === Compute power spectrum of template ===\n",
- "x_freq, y_freq, power = get_power_2d_adele(template_2d)\n",
- "power = power.flatten()\n",
- "valid = power > 1e-20\n",
- "power = power[valid]\n",
- "sorted_idx = np.argsort(power)[::-1] # np.argsort with [::-1] gives descending order\n",
- "power = power[sorted_idx]\n",
- "# print(\"Power in x: {}\".format(power))\n",
- "\n",
- "# Plot theoretical lines\n",
- "alpha_values = [np.sum(power[k:]) for k in range(len(power))]\n",
- "coef = 1 / (p1 * p2)\n",
- "for k, alpha in enumerate(alpha_values):\n",
- " ax.axhline(y=coef * alpha, color=\"black\", linestyle=\"--\", linewidth=2, zorder=-2)\n",
- "\n",
- "\n",
- "# # Steps\n",
- "# steps = [1, 20, 50, len(param_hist) - 1] # make sure these are < len(rnn_param_hist)\n",
- "# for step in steps:\n",
- "# ax.axvline(step, c=\"k\")\n",
- "\n",
- "# plot in log scale\n",
- "# ax.set_xscale(\"log\")\n",
- "# ax.set_yscale(\"log\")\n",
- "# ax.set_ylim(3.5e-2, 1.3e-1)\n",
- "\n",
- "ax.set_xlabel(\"Epochs\", fontsize=24)\n",
- "ax.set_ylabel(\"Train Loss\", fontsize=24)\n",
- "\n",
- "style_axes(ax)\n",
- "\n",
- "plt.grid(False)\n",
- "plt.tight_layout()\n",
- "# plt.savefig(\"rnn-loss.pdf\", bbox_inches=\"tight\")\n",
- "plt.show()"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "d1037a02"
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Prediction Plot"
- ],
- "id": "d4f01b69"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "import torch\n",
- "\n",
- "# --- config ---\n",
- "steps_2d = [1, 5] # , 10, len(param_hist_2d) - 1]\n",
- "example_idx = int(np.random.randint(len(Y_seq_2d_t)))\n",
- "cmap = \"gray\"\n",
- "origin = \"upper\" # use \"lower\" if you prefer cartesian orientation\n",
- "\n",
- "device = next(rnn_2d.parameters()).device\n",
- "rnn_2d.to(device).eval()\n",
- "\n",
- "# --- collect preds for the SAME example across steps ---\n",
- "preds = []\n",
- "with torch.no_grad():\n",
- " truth_2d = Y_seq_2d_t[example_idx].reshape(p1, p2).cpu().numpy()\n",
- "\n",
- "for step in steps_2d:\n",
- " rnn_2d.load_state_dict(param_hist_2d[step], strict=True)\n",
- " with torch.no_grad():\n",
- " x = X_seq_2d_t[example_idx : example_idx + 1].to(device) # (1, k, p1*p2)\n",
- " pred_2d = rnn_2d(x).reshape(p1, p2).detach().cpu().numpy()\n",
- " preds.append(pred_2d)\n",
- "\n",
- "# --- shared color scale ---\n",
- "vmin = np.min(truth_2d) # min(np.min(truth_2d), *(np.min(p) for p in preds))\n",
- "vmax = np.max(truth_2d) # max(np.max(truth_2d), *(np.max(p) for p in preds))\n",
- "\n",
- "# --- plot: rows = [Prediction, Target], cols = time steps ---\n",
- "fig, axes = plt.subplots(\n",
- " 2, len(steps_2d), figsize=(3.5 * len(steps_2d), 6), layout=\"constrained\"\n",
- ")\n",
- "\n",
- "for col, (step, pred_2d) in enumerate(zip(steps_2d, preds)):\n",
- " im = axes[0, col].imshow(pred_2d, vmin=vmin, vmax=vmax, cmap=cmap, origin=origin)\n",
- " axes[0, col].set_title(f\"t = {step}\", fontsize=12)\n",
- " axes[0, col].set_xticks([])\n",
- " axes[0, col].set_yticks([])\n",
- "\n",
- " axes[1, col].imshow(truth_2d, vmin=vmin, vmax=vmax, cmap=cmap, origin=origin)\n",
- " axes[1, col].set_xticks([])\n",
- " axes[1, col].set_yticks([])\n",
- "\n",
- "axes[0, 0].set_ylabel(\"Prediction\", fontsize=12)\n",
- "axes[1, 0].set_ylabel(\"Target\", fontsize=12)\n",
- "\n",
- "# single shared colorbar on the right\n",
- "fig.colorbar(im, ax=axes, location=\"right\", shrink=0.9, pad=0.02).set_label(\n",
- " \"value\", fontsize=12\n",
- ")\n",
- "\n",
- "plt.show()"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "11e29014"
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Prediction Power Spectrum Plot"
- ],
- "id": "b744c04c"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "## 2D Power Spectrum Over Time (top-K template freqs)\n",
- "\n",
- "# --- how many frequencies to track (includes DC if it ranks in top-K) ---\n",
- "num_freqs_to_track = 10 # change as you like\n",
- "\n",
- "# --- pick top-K frequencies from the template's 2D power (non-redundant half-plane) ---\n",
- "template_power_2d, _, _ = get_power_2d(template_2d) # shape (p1, p2)\n",
- "tracked_freqs = topk_template_freqs(template_2d, K=num_freqs_to_track)\n",
- "target_powers = {(kx, ky): template_power_2d[kx, ky] for (kx, ky) in tracked_freqs}\n",
- "\n",
- "# --- choose analysis steps (log-spaced) ---\n",
- "num_points = 100\n",
- "num_samples = 100\n",
- "T = len(param_hist_2d)\n",
- "steps_2d_analysis = np.unique(\n",
- " np.logspace(0, np.log10(max(T - 1, 1)), num_points, dtype=int)\n",
- ")\n",
- "steps_2d_analysis = steps_2d_analysis[steps_2d_analysis < T]\n",
- "steps_2d_analysis = np.insert(steps_2d_analysis, 0, 0)\n",
- "\n",
- "# --- track average output power at those frequencies over training ---\n",
- "powers_over_time_2d = {freq: [] for freq in tracked_freqs}\n",
- "\n",
- "with torch.no_grad():\n",
- " for step in tqdm(steps_2d_analysis, desc=\"Computing power spectra\"):\n",
- " rnn_2d.load_state_dict(param_hist_2d[step], strict=True)\n",
- " rnn_2d.eval()\n",
- "\n",
- " outputs_flat = (\n",
- " rnn_2d(X_seq_2d_t[:num_samples].to(device)).detach().cpu().numpy()\n",
- " ) # (N, p1*p2)\n",
- "\n",
- " # average power over batch\n",
- " # (simple loop for clarity; vectorize later if needed)\n",
- " powers_batch = []\n",
- " for i in range(outputs_flat.shape[0]):\n",
- " out_2d = outputs_flat[i].reshape(p1, p2)\n",
- " power_i, _, _ = get_power_2d(out_2d) # (p1, p2)\n",
- " powers_batch.append(power_i)\n",
- " avg_power = np.mean(powers_batch, axis=0) # (p1, p2)\n",
- "\n",
- " for kx, ky in tracked_freqs:\n",
- " powers_over_time_2d[(kx, ky)].append(avg_power[kx, ky])\n",
- "\n",
- "# --- convert lists to arrays ---\n",
- "for freq in tracked_freqs:\n",
- " powers_over_time_2d[freq] = np.array(powers_over_time_2d[freq])"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "43a6b047"
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Plot prediction power spectrum over time"
- ],
- "id": "cbfd47eb"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "from matplotlib.ticker import FormatStrFormatter\n",
- "\n",
- "# Colors used for both the bands (top) and the mode curves (bottom)\n",
- "colors_2d = plt.cm.tab10(np.linspace(0, 1, len(tracked_freqs)))\n",
- "\n",
- "# Make the figure a bit shorter and reduce vertical gap between plots\n",
- "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 10), sharex=True)\n",
- "fig.subplots_adjust(left=0.12, right=0.98, top=0.96, bottom=0.10, hspace=0.12)\n",
- "\n",
- "# ---------------- Top: training loss + theory lines ----------------\n",
- "x = np.arange(1, len(loss_hist_2d) + 1)\n",
- "ax1.plot(x, loss_hist_2d, lw=4, zorder=5)\n",
- "\n",
- "# Compute power spectrum of template for theory lines\n",
- "fy, fx, power = get_power_2d_adele(template_2d) # adapt unpacking if your func differs\n",
- "power_flat = np.sort(power.flatten()[power.flatten() > 1e-20])[::-1]\n",
- "\n",
- "# Theory levels (cumulative tail sums)\n",
- "alpha_values = np.array([np.sum(power_flat[k:]) for k in range(len(power_flat))])\n",
- "coef = 1.0 / (p1 * p2)\n",
- "y_levels = coef * alpha_values # strictly decreasing\n",
- "\n",
- "# Shade horizontal bands between successive theory lines\n",
- "n_bands = min(len(tracked_freqs), len(y_levels) - 1)\n",
- "for i in range(n_bands):\n",
- " y_top = y_levels[i]\n",
- " y_bot = y_levels[i + 1]\n",
- " ax1.axhspan(y_bot, y_top, facecolor=colors_2d[i], alpha=0.15, zorder=-3)\n",
- "\n",
- "# Draw the black theory lines\n",
- "for y in y_levels[: n_bands + 1]:\n",
- " ax1.axhline(y=y, color=\"black\", linestyle=\"--\", linewidth=2, zorder=-2)\n",
- "\n",
- "# ax1.set_ylim(3.5e-2, 1.3e-1)\n",
- "ax1.set_ylabel(\"Train Loss\", fontsize=24)\n",
- "style_axes(ax1)\n",
- "ax1.grid(False)\n",
- "ax1.tick_params(labelbottom=False) # only show x ticks on bottom plot\n",
- "\n",
- "# ---------------- Bottom: tracked mode power over time ----------------\n",
- "for i, (kx, ky) in enumerate(tracked_freqs):\n",
- " ax2.plot(steps_2d_analysis, powers_over_time_2d[(kx, ky)], color=colors_2d[i], lw=3)\n",
- " ax2.axhline(\n",
- " target_powers[(kx, ky)],\n",
- " color=colors_2d[i],\n",
- " linestyle=\"dotted\",\n",
- " linewidth=2,\n",
- " alpha=0.5,\n",
- " )\n",
- "\n",
- "ax2.set_xlabel(\"Epochs\", fontsize=20)\n",
- "ax2.set_ylabel(\"Power in Prediction\", fontsize=24)\n",
- "ax2.grid(True, alpha=0.3)\n",
- "style_axes(ax2)\n",
- "ax2.yaxis.set_major_formatter(FormatStrFormatter(\"%.1f\"))\n",
- "\n",
- "# No legend; plots are closer via reduced hspace and smaller figure height\n",
- "plt.savefig(\"train-loss.pdf\", bbox_inches=\"tight\")\n",
- "plt.show()"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "6ba47465-6ab8-41dd-8b2d-242a27217866"
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Plot reference Fourier modes (irreps of G)"
- ],
- "id": "2f2b63f9"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "import os\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "import matplotlib.patheffects as pe\n",
- "\n",
- "# Expected to exist:\n",
- "# - tracked_freqs : list[tuple[int,int]]\n",
- "# - p1, p2 : ints\n",
- "\n",
- "# Colors\n",
- "try:\n",
- " colors_2d\n",
- "except NameError:\n",
- " colors_2d = plt.cm.tab10(np.linspace(0, 1, len(tracked_freqs)))\n",
- "\n",
- "\n",
- "def fourier_mode_2d(p1: int, p2: int, kx: int, ky: int, phase: float = 0.0):\n",
- " y = np.arange(p1)[:, None]\n",
- " x = np.arange(p2)[None, :]\n",
- " mode = np.cos(2 * np.pi * (ky * y / p1 + kx * x / p2) + phase)\n",
- " mmin, mmax = mode.min(), mode.max()\n",
- " return (mode - mmin) / (mmax - mmin) if mmax > mmin else mode\n",
- "\n",
- "\n",
- "def signed_k(k: int, n: int) -> int:\n",
- " return k if k <= n // 2 else k - n\n",
- "\n",
- "\n",
- "def pretty_k(k: int, n: int) -> str:\n",
- " if n % 2 == 0 and k == n // 2:\n",
- " return r\"\\pm{}\".format(n // 2)\n",
- " return f\"{signed_k(k, n)}\"\n",
- "\n",
- "\n",
- "# ---- Save each mode separately ----\n",
- "out_dir = \"fourier_modes\"\n",
- "os.makedirs(out_dir, exist_ok=True)\n",
- "\n",
- "for i, (kx, ky) in enumerate(tracked_freqs):\n",
- " img = fourier_mode_2d(p1, p2, kx, ky)\n",
- "\n",
- " fig, ax = plt.subplots(figsize=(3.2, 2.2))\n",
- " ax.imshow(img, cmap=\"RdBu_r\", origin=\"upper\")\n",
- " ax.set_xticks([])\n",
- " ax.set_yticks([])\n",
- "\n",
- " # Colored border\n",
- " for side in (\"left\", \"right\", \"top\", \"bottom\"):\n",
- " ax.spines[side].set_edgecolor(colors_2d[i])\n",
- " ax.spines[side].set_linewidth(8)\n",
- "\n",
- " # Big, colored, centered label\n",
- " kx_label = pretty_k(kx, p2)\n",
- " ky_label = pretty_k(ky, p1)\n",
- " ax.text(\n",
- " 0.5,\n",
- " 0.5,\n",
- " f\"$k=({kx_label},{ky_label})$\",\n",
- " color=colors_2d[i],\n",
- " fontsize=25,\n",
- " fontweight=\"bold\",\n",
- " ha=\"center\",\n",
- " va=\"center\",\n",
- " transform=ax.transAxes,\n",
- " path_effects=[pe.withStroke(linewidth=3, foreground=\"white\", alpha=0.8)],\n",
- " )\n",
- "\n",
- " plt.tight_layout()\n",
- "\n",
- " # Filenames include both raw and signed indices\n",
- " kx_signed, ky_signed = signed_k(kx, p2), signed_k(ky, p1)\n",
- " base = f\"mode_{i:03d}_kx{kx}_ky{ky}_signed_{kx_signed}_{ky_signed}\"\n",
- " png_path = os.path.join(out_dir, base + \".png\")\n",
- " npy_path = os.path.join(out_dir, base + \".npy\")\n",
- "\n",
- " fig.savefig(png_path, dpi=300, bbox_inches=\"tight\")\n",
- " plt.close(fig)\n",
- "\n",
- " # Optional: save the normalized array too\n",
- " np.save(npy_path, img)\n",
- "\n",
- "print(f\"Saved {len(tracked_freqs)} mode images (and .npy arrays) to: {out_dir}/\")"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "6f3db659-353d-40b2-94cf-37ed2153f0b0"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "# ---- Save one tall, vertically stacked image (labels on RIGHT) with extra whitespace\n",
- "# and safe padding so thick borders aren't trimmed ----\n",
- "stack_name = \"fourier_modes_stacked_rightlabel_spaced.png\" # or \".pdf\"\n",
- "stack_path = os.path.join(out_dir, stack_name)\n",
- "\n",
- "n = len(tracked_freqs)\n",
- "\n",
- "# Panel geometry and gap (increase gap_h_in for more whitespace)\n",
- "panel_h_in = 2.2\n",
- "gap_h_in = 0.35 # whitespace between rows\n",
- "# total width = image col + label col\n",
- "fig_w_in = 4.6\n",
- "fig_h_in = n * panel_h_in + (n - 1) * gap_h_in\n",
- "\n",
- "import matplotlib.gridspec as gridspec\n",
- "\n",
- "fig = plt.figure(figsize=(fig_w_in, fig_h_in), dpi=300)\n",
- "\n",
- "# Rows alternate: [panel, gap, panel, gap, ..., panel]\n",
- "rows = 2 * n - 1\n",
- "height_ratios = []\n",
- "for i in range(n):\n",
- " height_ratios.append(panel_h_in)\n",
- " if i < n - 1:\n",
- " height_ratios.append(gap_h_in)\n",
- "\n",
- "# Right-labeled: image on LEFT, label on RIGHT\n",
- "gs = gridspec.GridSpec(\n",
- " nrows=rows, ncols=2,\n",
- " width_ratios=[1.0, 0.46], # image : label (bump right col if labels are long)\n",
- " height_ratios=height_ratios,\n",
- " wspace=0.0, hspace=0.0\n",
- ")\n",
- "\n",
- "for i, (kx, ky) in enumerate(tracked_freqs):\n",
- " r = 2 * i # even rows are content; odd rows are spacers\n",
- "\n",
- " # Image axis (left)\n",
- " ax_img = fig.add_subplot(gs[r, 0])\n",
- " img = fourier_mode_2d(p1, p2, kx, ky)\n",
- " ax_img.imshow(img, cmap=\"RdBu_r\", origin=\"upper\", aspect=\"equal\")\n",
- " ax_img.set_xticks([]); ax_img.set_yticks([])\n",
- "\n",
- " # Colored border around the image only\n",
- " for side in (\"left\", \"right\", \"top\", \"bottom\"):\n",
- " ax_img.spines[side].set_edgecolor(colors_2d[i])\n",
- " ax_img.spines[side].set_linewidth(8)\n",
- "\n",
- " # Label axis (right)\n",
- " ax_label = fig.add_subplot(gs[r, 1])\n",
- " ax_label.set_axis_off()\n",
- " kx_label = pretty_k(kx, p2)\n",
- " ky_label = pretty_k(ky, p1)\n",
- " ax_label.text(\n",
- " 0.0, 0.5, f\"$k=({kx_label},{ky_label})$\",\n",
- " color=colors_2d[i], fontsize=45, fontweight=\"bold\",\n",
- " ha=\"left\", va=\"center\", transform=ax_label.transAxes,\n",
- " path_effects=[pe.withStroke(linewidth=3, foreground=\"white\", alpha=0.8)]\n",
- " )\n",
- "\n",
- "# Pull content slightly in from the canvas edges so spines aren't flush\n",
- "fig.subplots_adjust(left=0.02, right=0.98, top=0.985, bottom=0.015)\n",
- "\n",
- "# Save with a tiny pad so thick borders aren't clipped by 'tight' bbox\n",
- "fig.savefig(stack_path, dpi=300, bbox_inches=\"tight\", pad_inches=0.12)\n",
- "plt.close(fig)\n",
- "\n",
- "print(f\"Saved vertical stack with RIGHT labels and spacing (no edge clipping): {stack_path}\")"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "996520ee-f19f-48a4-b453-a22274a91661"
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## W_out Plot"
- ],
- "id": "5932a60e-e3d1-4011-946e-54dfae1a62f6"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "# ========= Inputs you already have =========\n",
- "# param_hist_2d : list of state-dicts with key \"W_out\"\n",
- "# p1, p2 : reshape dims (p1 * p2 == D)\n",
- "\n",
- "# ========= Config =========\n",
- "steps_2d = [1, 5] # 100, len(param_hist_2d) - 1]\n",
- "dead_thresh_l2 = 0.25 # absolute L2-norm threshold for \"dead\" neurons\n",
- "heat_cmap = \"RdBu_r\" # colormap for heatmaps (weight values)\n",
- "border_lw = 5.0 # border width for neuron tiles\n",
- "title_fs = 18 # suptitle font size\n",
- "\n",
- "# Colors for categories\n",
- "palette = plt.cm.tab10(np.linspace(0, 1, len(tracked_freqs)))\n",
- "dead_color = (0.6, 0.6, 0.6, 1.0)\n",
- "\n",
- "\n",
- "# ========= Tiny helpers =========\n",
- "def squareish_grid(n):\n",
- " c = int(np.ceil(np.sqrt(n)))\n",
- " r = int(np.ceil(n / c))\n",
- " return r, c\n",
- "\n",
- "\n",
- "def tracked_power_from_fft2(power2d, kx, ky, p1, p2):\n",
- " \"\"\"Sum power at (kx,ky) and its real-signal mirror (-kx,-ky).\"\"\"\n",
- " i0, j0 = kx % p1, ky % p2\n",
- " i1, j1 = (-kx) % p1, (-ky) % p2\n",
- " if (i0, j0) == (i1, j1):\n",
- " return power2d[i0, j0]\n",
- " return power2d[i0, j0] + power2d[i1, j1]\n",
- "\n",
- "\n",
- "# ========= Prep: sizes and global color limits =========\n",
- "W0 = param_hist_2d[steps_2d[0]][\"W_out\"].detach().cpu().numpy().T # (H, D)\n",
- "H, D = W0.shape\n",
- "assert p1 * p2 == D, f\"p1*p2 ({p1*p2}) must equal D ({D}).\"\n",
- "\n",
- "vmin, vmax = np.inf, -np.inf\n",
- "for step in steps_2d:\n",
- " W = param_hist_2d[step][\"W_out\"].detach().cpu().numpy().T\n",
- " vmin = min(vmin, W.min())\n",
- " vmax = max(vmax, W.max())\n",
- "\n",
- "# ========= One figure per time step =========\n",
- "R_ner, C_ner = squareish_grid(H)\n",
- "tile_w, tile_h = 2, 2 # inches per neuron tile\n",
- "figsize = (C_ner * tile_w, R_ner * tile_h)\n",
- "\n",
- "for step in steps_2d:\n",
- " W = param_hist_2d[step][\"W_out\"].detach().cpu().numpy().T # (H, D)\n",
- "\n",
- " # Dominant tracked frequency + dead mask\n",
- " dom_idx = np.empty(H, dtype=int)\n",
- " l2 = np.linalg.norm(W, axis=1)\n",
- " dead_mask = l2 < dead_thresh_l2\n",
- "\n",
- " for j in range(H):\n",
- " m = W[j].reshape(p1, p2)\n",
- " F = np.fft.fft2(m)\n",
- " P = (F.conj() * F).real\n",
- " tp = [tracked_power_from_fft2(P, kx, ky, p1, p2) for (kx, ky) in tracked_freqs]\n",
- " dom_idx[j] = int(np.argmax(tp))\n",
- "\n",
- " edge_colors = palette[dom_idx]\n",
- " edge_colors[dead_mask] = dead_color\n",
- "\n",
- " # Build figure for this time step\n",
- " fig = plt.figure(figsize=figsize)\n",
- " # fig.suptitle(f\"step {step}\", fontsize=title_fs, fontweight=\"bold\", y=0.98)\n",
- " gs = gridspec.GridSpec(R_ner, C_ner, figure=fig, wspace=0.06, hspace=0.06)\n",
- "\n",
- " # Plot tiles\n",
- " for j in range(R_ner * C_ner):\n",
- " ax = fig.add_subplot(gs[j // C_ner, j % C_ner])\n",
- " if j < H:\n",
- " m = W[j].reshape(p1, p2)\n",
- " im = ax.imshow(\n",
- " m, vmin=vmin, vmax=vmax, origin=\"lower\", aspect=\"equal\", cmap=heat_cmap\n",
- " )\n",
- " # border highlight\n",
- " ec = edge_colors[j]\n",
- " for sp in ax.spines.values():\n",
- " sp.set_edgecolor(ec)\n",
- " sp.set_linewidth(border_lw)\n",
- " else:\n",
- " ax.axis(\"off\")\n",
- "\n",
- " ax.set_xticks([])\n",
- " ax.set_yticks([])\n",
- " plt.savefig(f\"step {step}\", bbox_inches=\"tight\", dpi=200)\n",
- " plt.show()\n",
- "\n",
- "# ========= Standalone GLOBAL COLORBAR figure =========\n",
- "fig_cb = plt.figure(figsize=(4, 1.0)) # wide short bar\n",
- "ax_cb = fig_cb.add_axes([0.1, 0.35, 0.8, 0.3]) # [left, bottom, width, height]\n",
- "norm = Normalize(vmin=vmin, vmax=vmax)\n",
- "sm = cm.ScalarMappable(norm=norm, cmap=heat_cmap)\n",
- "cbar = fig_cb.colorbar(sm, cax=ax_cb, orientation=\"horizontal\")\n",
- "cbar.set_label(\"Weight value\")\n",
- "plt.show()\n",
- "\n",
- "# ========= Standalone LEGEND figure for dominant frequency =========\n",
- "fig_legend = plt.figure(figsize=(5, 1.6))\n",
- "ax_leg = fig_legend.add_subplot(111)\n",
- "ax_leg.axis(\"off\")\n",
- "\n",
- "# Use patches with COLORED EDGES (to match tile borders)\n",
- "handles = [\n",
- " Patch(\n",
- " facecolor=\"white\", edgecolor=palette[i], linewidth=2.5, label=tracked_freqs[i]\n",
- " )\n",
- " for i in range(len(tracked_freqs))\n",
- "]\n",
- "handles.append(\n",
- " Patch(facecolor=\"white\", edgecolor=dead_color, linewidth=2.5, label=\"dead\")\n",
- ")\n",
- "\n",
- "leg = ax_leg.legend(\n",
- " handles=handles, ncol=4, frameon=True, loc=\"center\", title=\"Dominant frequency\"\n",
- ")\n",
- "plt.show()"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "44adb93e-a396-4b7c-8350-056934a49ddd"
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## W_mix Plot"
- ],
- "id": "d3232d4a-d65e-4282-be33-00da055ebb98"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "# =========================\n",
- "# Config\n",
- "# =========================\n",
- "within_group_order = \"phase\" # 'phase' | 'power' | 'phase_power' | 'none'\n",
- "outfile = \"Wmix_grouped_by_Wo_freq_labeled.pdf\"\n",
- "cmap = \"RdBu_r\"\n",
- "\n",
- "TITLE_FONTSIZE = 18\n",
- "LABEL_FONTSIZE = 11\n",
- "CBAR_LABEL_SIZE = 12\n",
- "CBAR_TICK_SIZE = 11\n",
- "SEPARATOR_LW = 0.9\n",
- "BLOCK_EDGE_LW = 2.0\n",
- "\n",
- "dead_l2_thresh = 0.1\n",
- "dead_position = \"last\"\n",
- "\n",
- "# how many template freqs to track/label\n",
- "num_freqs_to_track = 7 # e.g., like your old list length\n",
- "\n",
- "# You must have template_2d (shape p1 x p2) defined.\n",
- "p1, p2 = template_2d.shape\n",
- "\n",
- "\n",
- "# =========================\n",
- "# Helpers (2D, top-K tracked)\n",
- "# =========================\n",
- "def topk_template_freqs(template_2d: np.ndarray, K: int, min_power: float = 1e-20):\n",
- " # Use your get_power_2d_adele(template_2d) -> (freqs_u, freqs_v, power)\n",
- " freqs_u, freqs_v, power = get_power_2d_adele(\n",
- " template_2d\n",
- " ) # power: (p1, p2//2+1) or similar\n",
- " shp = power.shape\n",
- " flat = power.ravel()\n",
- " mask = flat > min_power\n",
- " if not np.any(mask):\n",
- " return []\n",
- " top_idx = np.flatnonzero(mask)[np.argsort(flat[mask])[::-1]][:K]\n",
- " kx, ky = np.unravel_index(top_idx, shp)\n",
- " return list(zip(kx.tolist(), ky.tolist()))\n",
- "\n",
- "\n",
- "def tracked_power_from_fft2(power2d, kx, ky, p1, p2):\n",
- " i0, j0 = kx % p1, ky % p2\n",
- " i1, j1 = (-kx) % p1, (-ky) % p2\n",
- " if (i0, j0) == (i1, j1):\n",
- " return float(power2d[i0, j0])\n",
- " return float(power2d[i0, j0] + power2d[i1, j1])\n",
- "\n",
- "\n",
- "def analyze_Wo_tracked(sd, tracked_freqs, p1, p2):\n",
- " \"\"\"\n",
- " For each neuron row of W_o, find the dominant freq among tracked_freqs.\n",
- " Returns:\n",
- " dom_idx (index into tracked_freqs), phase (at rep bin), dom_power, l2, D\n",
- " \"\"\"\n",
- " Wo = sd[\"W_out\"].detach().cpu().numpy() # (p, H) with p = p1*p2\n",
- " W = Wo.T # (H, p)\n",
- " H, D = W.shape\n",
- " assert D == p1 * p2\n",
- "\n",
- " dom_idx = np.empty(H, dtype=int)\n",
- " dom_pow = np.empty(H, dtype=float)\n",
- " phase = np.empty(H, dtype=float)\n",
- " l2 = np.linalg.norm(W, axis=1)\n",
- "\n",
- " for j in range(H):\n",
- " m = W[j].reshape(p1, p2)\n",
- " F = np.fft.fft2(m)\n",
- " P = (F.conj() * F).real\n",
- " # power only at tracked bins (with symmetry accounted)\n",
- " tp = [tracked_power_from_fft2(P, kx, ky, p1, p2) for (kx, ky) in tracked_freqs]\n",
- " jj = int(np.argmax(tp))\n",
- " dom_idx[j] = jj\n",
- " # phase at representative positive bin\n",
- " i0, j0 = tracked_freqs[jj][0] % p1, tracked_freqs[jj][1] % p2\n",
- " phase[j] = np.angle(F[i0, j0])\n",
- " dom_pow[j] = tp[jj]\n",
- " return dom_idx, phase, dom_pow, l2, D\n",
- "\n",
- "\n",
- "def permutation_from_groups_with_dead(\n",
- " dom_idx,\n",
- " phase,\n",
- " dom_power,\n",
- " l2,\n",
- " *,\n",
- " within=\"phase\",\n",
- " dead_l2_thresh=1e-1,\n",
- " dead_position=\"last\",\n",
- "):\n",
- " dead_mask = l2 < float(dead_l2_thresh)\n",
- " groups = {}\n",
- " for i, f in enumerate(dom_idx):\n",
- " key = -1 if dead_mask[i] else int(f) # f is index into tracked_freqs\n",
- " groups.setdefault(key, []).append(i)\n",
- "\n",
- " freq_keys = sorted([k for k in groups.keys() if k >= 0])\n",
- " ordered_keys = (freq_keys + [-1]) if dead_position == \"last\" else ([-1] + freq_keys)\n",
- " group_keys = [k for k in ordered_keys if k in groups]\n",
- "\n",
- " perm, boundaries = [], []\n",
- " for f in group_keys:\n",
- " idxs = groups[f]\n",
- " if f == -1:\n",
- " idxs = sorted(idxs, key=lambda i: l2[i])\n",
- " else:\n",
- " if within == \"phase\" and phase is not None:\n",
- " idxs = sorted(idxs, key=lambda i: (phase[i] + 2 * np.pi) % (2 * np.pi))\n",
- " elif within == \"power\" and dom_power is not None:\n",
- " idxs = sorted(idxs, key=lambda i: -dom_power[i])\n",
- " elif (\n",
- " within == \"phase_power\" and phase is not None and dom_power is not None\n",
- " ):\n",
- " idxs = sorted(\n",
- " idxs,\n",
- " key=lambda i: ((phase[i] + 2 * np.pi) % (2 * np.pi), -dom_power[i]),\n",
- " )\n",
- " perm.extend(idxs)\n",
- " boundaries.append(len(perm))\n",
- " return np.array(perm, dtype=int), group_keys, boundaries\n",
- "\n",
- "\n",
- "def reorder_square(M, perm):\n",
- " return M[perm][:, perm]\n",
- "\n",
- "\n",
- "# labels & colors for the tracked list\n",
- "tracked_freqs = topk_template_freqs(template_2d, num_freqs_to_track)\n",
- "tracked_labels = [\n",
- " (\"DC\" if (kx, ky) == (0, 0) else f\"({kx},{ky})\") for (kx, ky) in tracked_freqs\n",
- "]\n",
- "\n",
- "\n",
- "def build_freq_colors_tracked(n):\n",
- " return plt.cm.tab10(np.linspace(0, 1, n))\n",
- "\n",
- "\n",
- "freq_colors = build_freq_colors_tracked(len(tracked_freqs))\n",
- "dead_gray = \"0.35\"\n",
- "\n",
- "\n",
- "def add_group_labels_top(\n",
- " ax, group_keys, boundaries, *, show_counts=True, rotation=0, fontsize=LABEL_FONTSIZE\n",
- "):\n",
- " starts = [0] + boundaries[:-1]\n",
- " ends = [b - 1 for b in boundaries]\n",
- " centers = [(s + e) / 2.0 for s, e in zip(starts, ends)]\n",
- " sizes = [e - s + 1 for s, e in zip(starts, ends)]\n",
- "\n",
- " labels = []\n",
- " colors = []\n",
- " for kk, nn in zip(group_keys, sizes):\n",
- " if kk == -1:\n",
- " base = \"DEAD\"\n",
- " clr = dead_gray\n",
- " else:\n",
- " base = tracked_labels[kk]\n",
- " clr = freq_colors[kk]\n",
- " labels.append(f\"{base}\\n(n={nn})\" if show_counts else base)\n",
- " colors.append(clr)\n",
- "\n",
- " ax.set_xticks(centers)\n",
- " ax.set_xticklabels(labels, rotation=rotation, fontsize=fontsize, ha=\"center\")\n",
- " ax.tick_params(\n",
- " axis=\"x\",\n",
- " bottom=False,\n",
- " top=True,\n",
- " labelbottom=False,\n",
- " labeltop=True,\n",
- " labelsize=fontsize,\n",
- " )\n",
- " for lbl, clr in zip(ax.get_xticklabels(), colors):\n",
- " lbl.set_color(clr)\n",
- "\n",
- "\n",
- "# =========================\n",
- "# Prepare steps & snapshots\n",
- "# =========================\n",
- "steps = [1, 5] # [50, 100, len(param_hist_2d) - 1]\n",
- "Wh_perm_list = []\n",
- "group_info_list = []\n",
- "\n",
- "for s in steps:\n",
- " sd = param_hist_2d[s]\n",
- "\n",
- " # analyze W_o against tracked 2D freqs\n",
- " dom_idx, phase, dom_power, l2, D = analyze_Wo_tracked(sd, tracked_freqs, p1, p2)\n",
- "\n",
- " # W_mix fallback to W_h\n",
- " if \"W_mix\" in sd:\n",
- " M = sd[\"W_mix\"].detach().cpu().numpy()\n",
- " elif \"W_h\" in sd:\n",
- " M = sd[\"W_h\"].detach().cpu().numpy()\n",
- " else:\n",
- " raise KeyError(\"Neither 'W_mix' nor 'W_h' found in state dict.\")\n",
- "\n",
- " perm, group_keys, boundaries = permutation_from_groups_with_dead(\n",
- " dom_idx,\n",
- " phase,\n",
- " dom_power,\n",
- " l2,\n",
- " within=within_group_order,\n",
- " dead_l2_thresh=dead_l2_thresh,\n",
- " dead_position=dead_position,\n",
- " )\n",
- "\n",
- " M_perm = reorder_square(M, perm)\n",
- " Wh_perm_list.append(M_perm)\n",
- " group_info_list.append((group_keys, boundaries))\n",
- "\n",
- "# Shared symmetric color limits\n",
- "vmax = max(np.max(np.abs(M)) for M in Wh_perm_list)\n",
- "vmin = -vmax if vmax > 0 else 0.0\n",
- "\n",
- "# =========================\n",
- "# Plot\n",
- "# =========================\n",
- "n = len(steps)\n",
- "fig, axes = plt.subplots(1, n, figsize=(3.8 * n, 3.8), constrained_layout=True)\n",
- "try:\n",
- " fig.set_constrained_layout_pads(\n",
- " w_pad=0.003, h_pad=0.003, wspace=0.003, hspace=0.003\n",
- " )\n",
- "except Exception:\n",
- " pass\n",
- "if n == 1:\n",
- " axes = [axes]\n",
- "\n",
- "im = None\n",
- "for j, (s, M_perm) in enumerate(zip(steps, Wh_perm_list)):\n",
- " ax = axes[j]\n",
- " im = ax.imshow(\n",
- " M_perm, cmap=cmap, vmin=vmin, vmax=vmax, aspect=\"equal\", interpolation=\"nearest\"\n",
- " )\n",
- "\n",
- " ax.set_yticks([])\n",
- " ax.tick_params(axis=\"x\", bottom=False)\n",
- "\n",
- " group_keys, boundaries = group_info_list[j]\n",
- " for b in boundaries[:-1]:\n",
- " ax.axhline(b - 0.5, color=\"k\", lw=SEPARATOR_LW, alpha=0.65)\n",
- " ax.axvline(b - 0.5, color=\"k\", lw=SEPARATOR_LW, alpha=0.65)\n",
- "\n",
- " starts = [0] + boundaries[:-1]\n",
- " ends = [b - 1 for b in boundaries]\n",
- " for kk, s0, e0 in zip(group_keys, starts, ends):\n",
- " if kk == -1: # skip DEAD box color\n",
- " continue\n",
- " size = e0 - s0 + 1\n",
- " rect = Rectangle(\n",
- " (s0 - 0.5, s0 - 0.5),\n",
- " width=size,\n",
- " height=size,\n",
- " fill=False,\n",
- " linewidth=BLOCK_EDGE_LW,\n",
- " edgecolor=freq_colors[kk],\n",
- " alpha=0.95,\n",
- " joinstyle=\"miter\",\n",
- " )\n",
- " ax.add_patch(rect)\n",
- "\n",
- " add_group_labels_top(\n",
- " ax,\n",
- " group_keys,\n",
- " boundaries,\n",
- " show_counts=True,\n",
- " rotation=0,\n",
- " fontsize=LABEL_FONTSIZE,\n",
- " )\n",
- " ax.set_xlabel(f\"t = {s}\", fontsize=TITLE_FONTSIZE, labelpad=8)\n",
- "\n",
- "# shared colorbar\n",
- "cbar = fig.colorbar(im, ax=axes, shrink=1.0, pad=0.012, aspect=18)\n",
- "cbar.ax.tick_params(labelsize=CBAR_TICK_SIZE)\n",
- "cbar.set_label(\"Weight value\", fontsize=CBAR_LABEL_SIZE)\n",
- "\n",
- "plt.savefig(outfile, bbox_inches=\"tight\", dpi=200)\n",
- "plt.show()"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "6caf85e2-4706-4695-8499-16aad2a3061d"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "# =========================\n",
- "# Save tight heatmaps (keep separators + rectangles; no labels/legend)\n",
- "# =========================\n",
- "import os\n",
- "from matplotlib.patches import Rectangle\n",
- "\n",
- "base, ext = os.path.splitext(outfile)\n",
- "if not ext:\n",
- " ext = \".png\" # change to \".pdf\" if preferred\n",
- "\n",
- "# Shared symmetric color limits across steps\n",
- "vmax = max(np.max(np.abs(M)) for M in Wh_perm_list)\n",
- "vmin = -vmax if vmax > 0 else 0.0\n",
- "\n",
- "for j, (s, M_perm) in enumerate(zip(steps, Wh_perm_list)):\n",
- " # Full-bleed figure (no margins)\n",
- " fig = plt.figure(figsize=(3.2, 3.2), dpi=200)\n",
- " ax = plt.Axes(fig, [0, 0, 1, 1]) # left, bottom, width, height (normalized)\n",
- " fig.add_axes(ax)\n",
- " ax.set_axis_off() # no ticks/labels/frame\n",
- "\n",
- " # Heatmap\n",
- " ax.imshow(\n",
- " M_perm, cmap=cmap, vmin=vmin, vmax=vmax, aspect=\"equal\", interpolation=\"nearest\"\n",
- " )\n",
- "\n",
- " # Group separators + rectangles\n",
- " group_keys, boundaries = group_info_list[j]\n",
- "\n",
- " # Thin separator lines between groups\n",
- " for b in boundaries[:-1]:\n",
- " ax.axhline(b - 0.5, color=\"k\", lw=SEPARATOR_LW, alpha=0.65)\n",
- " ax.axvline(b - 0.5, color=\"k\", lw=SEPARATOR_LW, alpha=0.65)\n",
- "\n",
- " # Colored block outlines for each non-DEAD group\n",
- " starts = [0] + boundaries[:-1]\n",
- " ends = [b - 1 for b in boundaries]\n",
- " for kk, s0, e0 in zip(group_keys, starts, ends):\n",
- " if kk == -1:\n",
- " continue # skip DEAD group outline\n",
- " size = e0 - s0 + 1\n",
- " rect = Rectangle(\n",
- " (s0 - 0.5, s0 - 0.5),\n",
- " width=size,\n",
- " height=size,\n",
- " fill=False,\n",
- " linewidth=BLOCK_EDGE_LW,\n",
- " edgecolor=freq_colors[kk],\n",
- " alpha=0.95,\n",
- " joinstyle=\"miter\",\n",
- " )\n",
- " ax.add_patch(rect)\n",
- "\n",
- " # Save ultra-tight (no padding, no legend/colorbar/labels)\n",
- " per_step_outfile = f\"{base}_t{s:04d}_tight{ext}\"\n",
- " fig.savefig(per_step_outfile, bbox_inches=\"tight\", pad_inches=0)\n",
- " plt.close(fig)\n",
- "\n",
- "print(\n",
- " f\"Saved {len(steps)} files like '{base}_t####_tight{ext}' (tight, no labels/legend).\"\n",
- ")"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "acd308eb-032e-46e2-8a55-582e2d06a793"
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Phase Plot"
- ],
- "id": "3493ec18-299a-4642-94ec-a13723c80882"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "# One figure with a POLAR SUBPLOT per module (tracked (kx, ky)), excluding DEAD and DC-dominant.\n",
- "# Uses the final checkpoint only (t = -1).\n",
- "\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "from matplotlib.lines import Line2D\n",
- "\n",
- "dead_thresh_l2 = 0.5 # absolute L2 threshold for \"dead\" neurons\n",
- "\n",
- "# palette (one color per tracked freq)\n",
- "if 'palette' not in locals() or len(palette) != len(tracked_freqs):\n",
- " palette = plt.cm.tab10(np.linspace(0, 1, len(tracked_freqs)))\n",
- "\n",
- "# ---------- helpers ----------\n",
- "def squareish_grid(n):\n",
- " c = int(np.ceil(np.sqrt(n)))\n",
- " r = int(np.ceil(n / c))\n",
- " return r, c\n",
- "\n",
- "def tracked_power_from_fft2(P, kx, ky, p1, p2):\n",
- " i0, j0 = kx % p1, ky % p2\n",
- " i1, j1 = (-kx) % p1, (-ky) % p2\n",
- " return P[i0, j0] if (i0, j0) == (i1, j1) else P[i0, j0] + P[i1, j1]\n",
- "\n",
- "def canonical_bin(kx, ky, p1, p2):\n",
- " i, j = kx % p1, ky % p2\n",
- " if j > p2 - j: # prefer upper half-plane in ky\n",
- " i, j = (-kx) % p1, (-ky) % p2\n",
- " if j == 0 or (p2 % 2 == 0 and j == p2 // 2): # ky edge: fold kx too\n",
- " if i > p1 - i:\n",
- " i = (-kx) % p1\n",
- " return i, j\n",
- "\n",
- "def phase_at_2d_bin(m2d, kx, ky, p1, p2):\n",
- " F = np.fft.fft2(m2d)\n",
- " i, j = canonical_bin(kx, ky, p1, p2)\n",
- " return float(np.angle(F[i, j])) # [-pi, pi]\n",
- "\n",
- "# ---------- final step only ----------\n",
- "sd = param_hist_2d[-1]\n",
- "Wo = sd[\"W_out\"].detach().cpu().numpy() # (p1*p2, H) or (H, p1*p2)\n",
- "if Wo.shape[0] == p1 * p2:\n",
- " W = Wo.T\n",
- "elif Wo.shape[1] == p1 * p2:\n",
- " W = Wo\n",
- "else:\n",
- " raise ValueError(f\"W_o has incompatible shape {Wo.shape}; expected one dim == p1*p2={p1*p2}\")\n",
- "H, D = W.shape\n",
- "assert D == p1 * p2\n",
- "\n",
- "# per-neuron dominant tracked freq, phase, radius\n",
- "dom_idx = np.full(H, -1, dtype=int)\n",
- "phase = np.full(H, np.nan, dtype=float)\n",
- "radius = np.linalg.norm(W, axis=1)\n",
- "dead = radius < dead_thresh_l2\n",
- "\n",
- "for i in range(H):\n",
- " if dead[i]:\n",
- " continue\n",
- " m = W[i].reshape(p1, p2)\n",
- " F = np.fft.fft2(m); P = (F.conj() * F).real\n",
- " tp = [tracked_power_from_fft2(P, kx, ky, p1, p2) for (kx, ky) in tracked_freqs]\n",
- " j = int(np.argmax(tp))\n",
- " kx, ky = tracked_freqs[j]\n",
- " if (kx, ky) == (0, 0): # DC has undefined phase \u2192 skip\n",
- " continue\n",
- " dom_idx[i] = j\n",
- " phase[i] = phase_at_2d_bin(m, kx, ky, p1, p2)\n",
- "\n",
- "# ---------- build module masks (exclude empty and DEAD) ----------\n",
- "valid = (~dead) & np.isfinite(phase)\n",
- "modules = [(j, (valid & (dom_idx == j))) for j in range(len(tracked_freqs))]\n",
- "modules = [(j, msk) for (j, msk) in modules if np.any(msk)] # keep non-empty modules only\n",
- "\n",
- "if len(modules) == 0:\n",
- " print(\"No non-dead, non-DC neurons found for any tracked module.\")\n",
- "else:\n",
- " # global radial limit for comparability across subplots\n",
- " rmax_global = max(float(radius[msk].max()) for _, msk in modules)\n",
- "\n",
- " R, C = squareish_grid(len(modules))\n",
- " fig, axes = plt.subplots(R, C, figsize=(3.8*C + 3.0, 4.2*R), subplot_kw={\"projection\": \"polar\"})\n",
- " axes = np.atleast_1d(axes).ravel()\n",
- "\n",
- " for ax, (j, msk) in zip(axes, modules):\n",
- " th = phase[msk]\n",
- " r = radius[msk]\n",
- "\n",
- " # sticks + markers\n",
- " for th_i, r_i in zip(th, r):\n",
- " ax.plot([th_i, th_i], [0.0, r_i], color=palette[j], linewidth=1.8, alpha=0.9)\n",
- " ax.scatter(th, r, s=70, color=palette[j], edgecolors='none', alpha=0.95)\n",
- "\n",
- " # styling\n",
- " ax.set_title(f\"{tracked_freqs[j]} | n={np.count_nonzero(msk)}\", fontsize=11, pad=8)\n",
- " ax.set_theta_zero_location(\"E\")\n",
- " ax.set_thetagrids(\n",
- " np.arange(0, 360, 90), # sparser grid to reduce clutter\n",
- " [\"0\", r\"$\\pi/2$\", r\"$\\pi$\", r\"$3\\pi/2$\"]\n",
- " )\n",
- " ax.set_yticklabels([])\n",
- " ax.spines[\"polar\"].set_linewidth(2)\n",
- " ax.set_rlim(0, 1.05 * rmax_global if rmax_global > 0 else 1.0)\n",
- "\n",
- " # hide any unused axes\n",
- " for ax in axes[len(modules):]:\n",
- " ax.set_visible(False)\n",
- "\n",
- " # legend with only shown modules, outside on the right\n",
- " handles = [Line2D([0],[0], marker='o', linestyle='', color=palette[j],\n",
- " label=f\"{tracked_freqs[j]}\", markersize=7)\n",
- " for (j, _) in modules]\n",
- " fig.legend(handles=handles, loc='center left', bbox_to_anchor=(1.02, 0.5),\n",
- " frameon=True, title=\"module (kx, ky)\")\n",
- "\n",
- " fig.subplots_adjust(right=0.82, wspace=0.35, hspace=0.35)\n",
- " # plt.savefig(\"phase_polar_modules_grid.png\", bbox_inches=\"tight\", dpi=170)\n",
- " plt.show()"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "d0a2b283-bb8b-4145-9519-3af401532cc4"
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Activation Plot"
- ],
- "id": "97cdde1e-d4ef-4437-bc1e-5584ab50fd64"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [
- "paths_xy = sequence_to_paths_xy(sequence_xy, p1, p2)\n",
- "\n",
- "# ========= Config =========\n",
- "dead_l2_thresh = 1.0 # absolute L2 threshold (on W_o rows) for DEAD\n",
- "if 'palette' not in locals() or len(palette) != len(tracked_freqs):\n",
- " palette = plt.cm.tab10(np.linspace(0, 1, len(tracked_freqs)))\n",
- "dead_color = (0.6, 0.6, 0.6, 1.0)\n",
- "\n",
- "tile_w, tile_h = 1.9, 1.9 # inches per neuron tile\n",
- "border_lw = 2.0\n",
- "title_fs = 14\n",
- "cmap_heat = \"viridis\"\n",
- "max_neurons_per_module = 9 # 3x3 grid\n",
- "\n",
- "# ========= Unroll hidden activations (bias-free quadratic RNN) =========\n",
- "@torch.no_grad()\n",
- "def unroll_hidden_batch_quadratic_no_bias(x_seq_t, rnn, batch_size=4096):\n",
- " \"\"\"\n",
- " For QuadraticRNN:\n",
- " h0 = template @ W_in^T\n",
- " h_t = (W_mix h_{t-1} + W_drive x_t)^2 for t = 0..k-1\n",
- " Returns: H_all of shape (N, k, d) with the hidden AFTER each update t.\n",
- " \"\"\"\n",
- " N, K, p = x_seq_t.shape\n",
- " d = rnn.W_mix.shape[0]\n",
- " H_all = torch.empty((N, K, d), device=x_seq_t.device, dtype=rnn.W_mix.dtype)\n",
- "\n",
- " # template -> (p,) then h0_vec -> (d,)\n",
- " tmpl = rnn.template\n",
- " if tmpl.dim() == 2 and tmpl.size(0) == 1:\n",
- " tmpl = tmpl.squeeze(0)\n",
- " h0_vec = tmpl @ rnn.W_in.T # (d,)\n",
- "\n",
- " for start in range(0, N, batch_size):\n",
- " end = min(N, start + batch_size)\n",
- " xb = x_seq_t[start:end] # (B, K, p)\n",
- " B = xb.size(0)\n",
- "\n",
- " # expand h0 across the batch\n",
- " h = h0_vec.expand(B, -1).contiguous() # (B, d)\n",
- "\n",
- " for t in range(0, K):\n",
- " xt = xb[:, t, :] # (B, p)\n",
- " pre = (h @ rnn.W_mix.T) + (xt @ rnn.W_drive.T)\n",
- " h = pre.pow(2)\n",
- " H_all[start:end, t, :] = h\n",
- "\n",
- " return H_all\n",
- "\n",
- "\n",
- "# ========= Collect activations =========\n",
- "if 'device' not in locals():\n",
- " device = next(rnn_2d.parameters()).device\n",
- "X_seq_t = torch.tensor(X_seq_2d, dtype=torch.float32, device=device)\n",
- "rnn_2d.to(device).eval()\n",
- "with torch.no_grad():\n",
- " H_all = unroll_hidden_batch_quadratic_no_bias(X_seq_t, rnn_2d, batch_size=4096) # (N, k, H)\n",
- "\n",
- "# ========= Prep numpy views =========\n",
- "H_all_np = H_all.detach().cpu().numpy() # (N, k, H)\n",
- "pos_xy = paths_xy.reshape(-1, 2) # (N*k, 2)\n",
- "pos_lin = pos_xy[:, 0] * p2 + pos_xy[:, 1] # linearized index in [0..p1*p2-1]\n",
- "acts_flat = H_all_np.reshape(-1, H_all_np.shape[2]) # (N*k, H)\n",
- "\n",
- "# Precompute counts per 2D position (shared by all neurons)\n",
- "counts_lin = np.bincount(pos_lin, minlength=p1*p2).astype(np.int64)\n",
- "counts = counts_lin.reshape(p1, p2)\n",
- "\n",
- "# ========= Module assignment from W_o using tracked 2D frequencies =========\n",
- "Wo = rnn_2d.W_out.detach().cpu().numpy() # (p1*p2, H)\n",
- "W = Wo.T # (H, p1*p2)\n",
- "Hdim = W.shape[0]\n",
- "l2_norm = np.linalg.norm(W, axis=1)\n",
- "is_dead = (l2_norm < dead_l2_thresh)\n",
- "\n",
- "def tracked_power_from_fft2(power2d, kx, ky, p1, p2):\n",
- " i0, j0 = kx % p1, ky % p2\n",
- " i1, j1 = (-kx) % p1, (-ky) % p2\n",
- " if (i0, j0) == (i1, j1):\n",
- " return power2d[i0, j0]\n",
- " return power2d[i0, j0] + power2d[i1, j1]\n",
- "\n",
- "dom_idx = np.empty(Hdim, dtype=int) # index into tracked_freqs\n",
- "dom_pw = np.empty(Hdim, dtype=float)\n",
- "phase = np.empty(Hdim, dtype=float) # phase at representative (i0, j0)\n",
- "\n",
- "for j in range(Hdim):\n",
- " m = W[j].reshape(p1, p2)\n",
- " F = np.fft.fft2(m)\n",
- " P = (F.conj() * F).real\n",
- " tp = [tracked_power_from_fft2(P, kx, ky, p1, p2) for (kx, ky) in tracked_freqs]\n",
- " jj = int(np.argmax(tp))\n",
- " dom_idx[j] = jj\n",
- " # phase at a consistent representative bin\n",
- " i0, j0 = tracked_freqs[jj][0] % p1, tracked_freqs[jj][1] % p2\n",
- " phase[j] = np.angle(F[i0, j0])\n",
- " dom_pw[j] = tp[jj]\n",
- "\n",
- "# Assign module id (DEAD = -1)\n",
- "module_id = np.where(is_dead, -1, dom_idx)\n",
- "\n",
- "# Group neurons by module (dead last)\n",
- "groups = {}\n",
- "for nid, mid in enumerate(module_id):\n",
- " groups.setdefault(int(mid), []).append(nid)\n",
- "\n",
- "freq_keys = [i for i in range(len(tracked_freqs)) if i in groups] # keep tracked order\n",
- "ordered_mods = freq_keys + ([-1] if -1 in groups else [])\n",
- "\n",
- "# ========= Pick up to 9 neurons/module (stable) =========\n",
- "def pick_neurons_for_module(m, max_neurons=9):\n",
- " idxs = groups[m]\n",
- " if m == -1:\n",
- " # DEAD: weakest first\n",
- " idxs = sorted(idxs, key=lambda i: l2_norm[i])\n",
- " else:\n",
- " # sort by (phase mod 2\u03c0, then -dom_power) for reproducibility\n",
- " idxs = sorted(idxs, key=lambda i: ((phase[i] + 2*np.pi) % (2*np.pi), -dom_pw[i]))\n",
- " return idxs[:min(max_neurons, len(idxs))]\n",
- "\n",
- "picked = {m: pick_neurons_for_module(m, max_neurons_per_module) for m in ordered_mods}\n",
- "\n",
- "# ========= Build mean activation map for a neuron =========\n",
- "# Efficient: single bincount per neuron; reuse global counts\n",
- "def mean_activation_map_for_neuron(nid: int) -> np.ndarray:\n",
- " sums_lin = np.bincount(pos_lin, weights=acts_flat[:, nid], minlength=p1*p2)\n",
- " sums = sums_lin.reshape(p1, p2)\n",
- " return sums / np.maximum(counts, 1) # avoid div-by-zero\n",
- "\n",
- "# ========= Plot: one 3\u00d73 figure per module (colorbar on far right) =========\n",
- "for m in ordered_mods:\n",
- " nids = picked[m]\n",
- " if len(nids) == 0:\n",
- " continue\n",
- "\n",
- " mlabel = (\"DEAD\" if m == -1 else tracked_freqs[m])\n",
- " mcolor = (dead_color if m == -1 else palette[m])\n",
- " safe = str(mlabel).replace(\"(\", \"\").replace(\")\", \"\").replace(\",\", \"_\").replace(\" \", \"\")\n",
- "\n",
- " # Compute maps for this module\n",
- " maps = [mean_activation_map_for_neuron(nid) for nid in nids]\n",
- "\n",
- " # ---- LOG SCALE: choose vmin as smallest positive, vmax as max ----\n",
- " vmax = max(float(mp.max()) for mp in maps)\n",
- " posmins = [float(mp[mp > 0].min()) for mp in maps if np.any(mp > 0)]\n",
- " use_log = (len(posmins) > 0) and (vmax > 0)\n",
- " if use_log:\n",
- " vmin_pos = min(posmins)\n",
- " norm = LogNorm(vmin=vmin_pos, vmax=vmax)\n",
- " else:\n",
- " # fallback to linear if no positive values\n",
- " vmin = min(float(mp.min()) for mp in maps)\n",
- " norm = None # linear\n",
- "\n",
- " nrows, ncols = 3, 3\n",
- " fig = plt.figure(figsize=(ncols * tile_w + 1.2, nrows * tile_h), constrained_layout=True)\n",
- " gs = fig.add_gridspec(nrows=nrows, ncols=ncols + 1, width_ratios=[1, 1, 1, 0.06], wspace=0.08, hspace=0.08)\n",
- "\n",
- " axes_r, last_im = [], None\n",
- " for r in range(nrows):\n",
- " for c in range(ncols):\n",
- " ax = fig.add_subplot(gs[r, c]); axes_r.append(ax)\n",
- "\n",
- " for i, ax in enumerate(axes_r):\n",
- " if i < len(nids):\n",
- " if use_log:\n",
- " im = ax.imshow(maps[i], norm=norm, origin=\"lower\", aspect=\"equal\", cmap=cmap_heat)\n",
- " # Optional: make zeros clearly visible\n",
- " # im.cmap.set_under('white')\n",
- " else:\n",
- " im = ax.imshow(maps[i], vmin=vmin, vmax=vmax, origin=\"lower\", aspect=\"equal\", cmap=cmap_heat)\n",
- " last_im = im\n",
- " ax.set_title(f\"h={nids[i]}\", fontsize=10, color=mcolor)\n",
- " for sp in ax.spines.values():\n",
- " sp.set_edgecolor(mcolor); sp.set_linewidth(border_lw)\n",
- " else:\n",
- " ax.axis(\"off\")\n",
- " ax.set_xticks([]); ax.set_yticks([])\n",
- "\n",
- " # Colorbar on the far right\n",
- " cax = fig.add_subplot(gs[:, -1])\n",
- " if last_im is not None:\n",
- " cbar = fig.colorbar(last_im, cax=cax)\n",
- " cbar.set_label(\"Mean activation (log scale)\" if use_log else \"Mean activation\", fontsize=11)\n",
- " # Optional: nice log ticks\n",
- " # if use_log:\n",
- " # import numpy as np\n",
- " # ticks = np.logspace(np.log10(vmin_pos), np.log10(vmax), num=5)\n",
- " # cbar.set_ticks(ticks)\n",
- "\n",
- " fig.suptitle(\n",
- " f\"Module {mlabel} | p1={p1}, p2={p2} (showing {len(nids)} of {len(groups[m])} neurons)\",\n",
- " fontsize=title_fs, color=mcolor\n",
- " )\n",
- " fig.supxlabel(\"position y\", fontsize=12)\n",
- " fig.supylabel(\"position x\", fontsize=12)\n",
- "\n",
- " plt.savefig(f\"activations2D_module_{safe}_3x3.png\", bbox_inches=\"tight\", dpi=170)\n",
- " plt.show()"
- ],
- "execution_count": null,
- "outputs": [],
- "id": "667c8114-4a0b-480e-9ba2-022122926a8d"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [],
- "execution_count": null,
- "outputs": [],
- "id": "be86d544-02bf-4623-b57b-15b16f3248a5"
- },
- {
- "cell_type": "code",
- "metadata": {},
- "source": [],
- "execution_count": null,
- "outputs": [],
- "id": "b8cbe5d1-9668-405f-b452-ab92dcad91ef"
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "group-agf",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.2"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
\ No newline at end of file
diff --git a/notebooks/C_n.ipynb b/notebooks/C_n.ipynb
deleted file mode 100644
index 82a1833..0000000
--- a/notebooks/C_n.ipynb
+++ /dev/null
@@ -1,804 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "51d11caf-0971-4324-b63b-819b714a9c3c",
- "metadata": {},
- "source": [
- "# Binary Group Composition with $C_n$"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "80f249f1-6985-4c73-86cd-04e1adac3e8d",
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import random\n",
- "import torch\n",
- "import torch.nn as nn\n",
- "from tqdm import tqdm\n",
- "import torch.optim as optim\n",
- "from torch.utils.data import DataLoader, TensorDataset\n",
- "import seaborn as sns\n",
- "import matplotlib.pyplot as plt\n",
- "import matplotlib.cm as cm\n",
- "from matplotlib.animation import FuncAnimation\n",
- "from matplotlib.ticker import FormatStrFormatter\n",
- "from matplotlib.ticker import FuncFormatter\n",
- "from matplotlib.ticker import MaxNLocator"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "9fd05577-db56-4d0a-bb93-1d0b48cecaf6",
- "metadata": {},
- "source": [
- "## Dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "f19bd1ad-9e8f-4720-b317-afe13fafae88",
- "metadata": {},
- "outputs": [],
- "source": [
- "def one_hot(p):\n",
- " \"\"\"One-hot encode an integer value in R^p.\"\"\"\n",
- " vec = np.zeros(p)\n",
- " vec[0] = 1\n",
- " return vec\n",
- "\n",
- "def generate_template(p, magnitude, exponent):\n",
- " weight = magnitude * np.power(np.arange(1, p), -exponent) # Power-law singular values\n",
- " template = np.ones(p) # Base term (DC component)\n",
- " for freq in range(1, p):\n",
- " template += weight[freq-1] * np.cos(np.arange(p) * freq / p * 2 * np.pi)\n",
- " return template / p\n",
- "\n",
- "def generate_fixed_template(p):\n",
- " # Generate template array from Fourier spectrum\n",
- " spectrum = np.zeros(p, dtype=complex)\n",
- " \n",
- " # Set only three frequencies with specific amplitudes\n",
- " spectrum[1] = 12.5 # Positive frequency\n",
- " spectrum[-1] = 12.5 # Negative frequency (conjugate)\n",
- " spectrum[2] = 10 # Positive frequency\n",
- " spectrum[-2] = 10 # Negative frequency (conjugate)\n",
- " spectrum[3] = 7.5 # Second frequency\n",
- " spectrum[-3] = 7.5 # Its conjugate\n",
- " spectrum[4] = 5 # Second frequency\n",
- " spectrum[-4] = 5 # Its conjugate\n",
- " spectrum[5] = 2.5 # Third frequency \n",
- " spectrum[-5] = 2.5 # Its conjugate\n",
- " \n",
- " # Generate signal from spectrum\n",
- " template = np.fft.ifft(spectrum).real\n",
- "\n",
- " return template\n",
- "\n",
- "def ModularAdditionDataset(p, template):\n",
- " # Initialize data arrays\n",
- " X = np.zeros((p * p, 2, p)) # Shape: (p^2, 2, p)\n",
- " Y = np.zeros((p * p, p)) # Shape: (p^2, p)\n",
- " \n",
- " # Generate the dataset\n",
- " idx = 0\n",
- " for a in range(p):\n",
- " for b in range(p):\n",
- " q = (a + b) % p # a + b mod p\n",
- " X[idx, 0, :] = np.roll(template, a)\n",
- " X[idx, 1, :] = np.roll(template, b)\n",
- " Y[idx, :] = np.roll(template, q)\n",
- " idx += 1\n",
- " \n",
- " return X, Y"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "7a0ecbbd-ceaf-4bef-af4a-13a22fa70063",
- "metadata": {},
- "source": [
- "## Architecture"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "2cf22b7d-49e7-445b-8742-2e75cd1fa55a",
- "metadata": {},
- "outputs": [],
- "source": [
- "class TwoLayerNet(nn.Module):\n",
- " def __init__(self, p, hidden_size, nonlinearity='square', init_scale=1.0, output_scale=1.0):\n",
- " super(TwoLayerNet, self).__init__()\n",
- " \n",
- " # Store dimensions\n",
- " self.p = p\n",
- " self.hidden_size = hidden_size\n",
- " self.nonlinearity = nonlinearity\n",
- " self.init_scale = init_scale\n",
- " self.output_scale = output_scale\n",
- " \n",
- " # Initialize parameters \n",
- " self.U = nn.Parameter(self.init_scale * torch.randn(hidden_size, p) / np.sqrt(2 * p)) # First p elements\n",
- " self.V = nn.Parameter(self.init_scale * torch.randn(hidden_size, p) / np.sqrt(2 * p)) # Second p elements\n",
- " self.W = nn.Parameter(self.init_scale * torch.randn(hidden_size, p) / np.sqrt(p)) # Second layer weights\n",
- "\n",
- " def forward(self, x):\n",
- " \n",
- " # First layer (linear and combined)\n",
- " x1 = x[:, :self.p] @ self.U.T\n",
- " x2 = x[:, self.p:] @ self.V.T\n",
- " x_combined = x1 + x2\n",
- "\n",
- " # Apply nonlinearity activation\n",
- " if self.nonlinearity == 'relu':\n",
- " x_combined = torch.relu(x_combined)\n",
- " elif self.nonlinearity == 'square':\n",
- " x_combined = x_combined**2\n",
- " elif self.nonlinearity == 'linear':\n",
- " x_combined = x_combined\n",
- " elif self.nonlinearity == 'tanh':\n",
- " x_combined = torch.tanh(x_combined)\n",
- " elif self.nonlinearity == 'gelu':\n",
- " gelu = torch.nn.GELU()\n",
- " x_combined = gelu(x_combined)\n",
- " else:\n",
- " raise ValueError(f\"Invalid nonlinearity '{self.nonlinearity}' provided.\")\n",
- "\n",
- " # Second layer (linear)\n",
- " x_out = x_combined @ self.W\n",
- "\n",
- " # Feature learning scaling\n",
- " x_out *= self.output_scale\n",
- " \n",
- " return x_out"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "f7e7336b-5c6e-48af-a357-2b2c877f6168",
- "metadata": {},
- "source": [
- "## Optimization"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "bc9ba87b-9607-4a4a-9b00-00c15adb2f5a",
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch\n",
- "from torch import nn\n",
- "\n",
- "class PerNeuronScaledSGD(torch.optim.Optimizer):\n",
- " \"\"\"\n",
- " Per-neuron scaled SGD optimizer that exploits model homogeneity.\n",
- " \n",
- " Learning rate scaling per neuron i:\n",
- " eta_i = lr * ||theta_i||^(2-degree)\n",
- " \n",
- " where:\n",
- " - theta_i comprises all parameters associated with neuron i\n",
- " - degree is the degree of homogeneity of the model\n",
- "\n",
- " For SequentialMLP with sequence length k:\n",
- " - theta_i = (W_in[i, :], W_out[:, i])\n",
- " - degree = k+1 (activation is x^k, one more layer for W_out)\n",
- "\n",
- " For TwoLayerNet:\n",
- " - theta_i = (U[i, :], V[i, :], W[i, :])\n",
- " - degree:\n",
- " * nonlinearity == 'square' -> 3\n",
- " * otherwise default 2 (unless explicitly provided)\n",
- " \"\"\"\n",
- "\n",
- " def __init__(self, \n",
- " model, \n",
- " lr=1.0, \n",
- " degree=None\n",
- " ) -> None:\n",
- " \"\"\"\n",
- " Args:\n",
- " model: SequentialMLP, TwoLayerNet, or compatible model\n",
- " lr: base learning rate\n",
- " degree: degree of homogeneity (exponent for norm-based scaling)\n",
- " If None, inferred from model:\n",
- " - SequentialMLP: uses k+1 where k is sequence length\n",
- " - TwoLayerNet:\n",
- " * 'square' -> 3\n",
- " * otherwise -> 2\n",
- " - Default: 2\n",
- " \"\"\"\n",
- " model_type = type(model).__name__\n",
- "\n",
- " # Infer degree of homogeneity from model if not provided\n",
- " if degree is None:\n",
- " if hasattr(model, 'k'): # SequentialMLP-style\n",
- " degree = model.k + 1\n",
- " elif model_type == 'TwoLayerNet':\n",
- " nl = getattr(model, 'nonlinearity', None)\n",
- " if nl == 'square':\n",
- " degree = 3\n",
- " else:\n",
- " # For relu/linear/tanh/gelu or unknown, fall back\n",
- " degree = 2\n",
- " else:\n",
- " # Default for quadratic-ish models\n",
- " degree = 2\n",
- " \n",
- " # Get model parameters\n",
- " params = list(model.parameters())\n",
- " \n",
- " super().__init__(\n",
- " [{'params': params, 'model': model, 'model_type': model_type}], \n",
- " dict(lr=lr, degree=degree)\n",
- " )\n",
- "\n",
- " @torch.no_grad()\n",
- " def step(self, closure=None):\n",
- " group = self.param_groups[0]\n",
- " model = group['model']\n",
- " lr = group['lr']\n",
- " degree = group['degree']\n",
- " model_type = group['model_type']\n",
- " \n",
- " if model_type == 'SequentialMLP':\n",
- " # SequentialMLP: W_in (d, k*p), W_out (p, d)\n",
- " W_in = model.W_in\n",
- " W_out = model.W_out\n",
- " g_in = W_in.grad\n",
- " g_out = W_out.grad\n",
- " \n",
- " if g_in is None or g_out is None:\n",
- " return\n",
- " \n",
- " # Per-neuron norms: theta_i = (W_in[i, :], W_out[:, i])\n",
- " u2 = (W_in**2).sum(dim=1) # (d,)\n",
- " w2 = (W_out**2).sum(dim=0) # (d,)\n",
- " theta_norm = torch.sqrt(u2 + w2 + 1e-12) # (d,)\n",
- " \n",
- " # Scale = ||theta_i||^(2-degree)\n",
- " scale = theta_norm.pow(2 - degree)\n",
- " \n",
- " # Scale each neuron's gradients\n",
- " g_in.mul_(scale.view(-1, 1))\n",
- " g_out.mul_(scale.view(1, -1))\n",
- " \n",
- " # SGD update\n",
- " W_in.add_(g_in, alpha=-lr)\n",
- " W_out.add_(g_out, alpha=-lr)\n",
- "\n",
- " elif model_type == 'TwoLayerNet':\n",
- " # TwoLayerNet: U (d, p), V (d, p), W (d, p)\n",
- " U = model.U\n",
- " V = model.V\n",
- " W = model.W\n",
- "\n",
- " g_U = U.grad\n",
- " g_V = V.grad\n",
- " g_W = W.grad\n",
- "\n",
- " if g_U is None or g_V is None or g_W is None:\n",
- " return\n",
- "\n",
- " # Per-neuron norms: theta_i = (U[i, :], V[i, :], W[i, :])\n",
- " u2 = (U**2).sum(dim=1) # (hidden_size,)\n",
- " v2 = (V**2).sum(dim=1) # (hidden_size,)\n",
- " w2 = (W**2).sum(dim=1) # (hidden_size,)\n",
- " theta_norm = torch.sqrt(u2 + v2 + w2 + 1e-12) # (hidden_size,)\n",
- "\n",
- " # Scale = ||theta_i||^(2-degree)\n",
- " scale = theta_norm.pow(2 - degree) # (hidden_size,)\n",
- "\n",
- " # Scale each neuron's gradients\n",
- " scale_view = scale.view(-1, 1) # (hidden_size, 1)\n",
- " g_U.mul_(scale_view)\n",
- " g_V.mul_(scale_view)\n",
- " g_W.mul_(scale_view)\n",
- "\n",
- " # SGD update\n",
- " U.add_(g_U, alpha=-lr)\n",
- " V.add_(g_V, alpha=-lr)\n",
- " W.add_(g_W, alpha=-lr)\n",
- "\n",
- " else:\n",
- " raise ValueError(f\"PerNeuronScaledSGD: Unsupported model structure with {model_type}\")\n",
- " \n",
- " return None\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "1035f81c-e877-4655-8640-4e4c3d323af8",
- "metadata": {},
- "outputs": [],
- "source": [
- "def test_accuracy(model, dataloader):\n",
- " correct = 0\n",
- " total = 0\n",
- " \n",
- " with torch.no_grad(): # Disable gradient calculation for evaluation\n",
- " for inputs, labels in dataloader:\n",
- " inputs = inputs.view(inputs.shape[0], -1) # Flatten input for FC layers\n",
- " outputs = model(inputs)\n",
- " _, predicted = torch.max(outputs, 1) # Get the index of the largest value (class)\n",
- " _, true_labels = torch.max(labels, 1) # Get the true class from the one-hot encoding\n",
- " correct += (predicted == true_labels).sum().item()\n",
- " total += labels.size(0)\n",
- " \n",
- " accuracy = 100 * correct / total\n",
- " return accuracy\n",
- "\n",
- "def train(model, dataloader, criterion, optimizer, epochs=100, verbose_interval=10):\n",
- " model.train() # Set the model to training mode\n",
- " loss_history = [] # List to store loss values\n",
- " accuracy_history = []\n",
- " param_history = []\n",
- "\n",
- " for epoch in range(epochs):\n",
- " running_loss = 0.0\n",
- " for inputs, labels in dataloader:\n",
- " inputs = inputs.view(inputs.shape[0], -1) # Flatten input for FC layers\n",
- "\n",
- " optimizer.zero_grad() # Zero gradients\n",
- " outputs = model(inputs) # Forward pass\n",
- " loss = criterion(outputs, labels) # Compute loss\n",
- " loss.backward() # Backpropagation\n",
- " optimizer.step() # Update weights\n",
- "\n",
- " running_loss += loss.item()\n",
- "\n",
- " # Append the average loss for the epoch to loss_history\n",
- " avg_loss = running_loss / len(dataloader)\n",
- " loss_history.append(avg_loss)\n",
- "\n",
- " # Append the accuracy\n",
- " model.eval()\n",
- " accuracy = test_accuracy(model, dataloader)\n",
- " accuracy_history.append(accuracy)\n",
- " model.train()\n",
- "\n",
- " # Save current model parameters\n",
- " current_params = {\n",
- " \"U\": model.U.detach().cpu().clone(),\n",
- " \"V\": model.V.detach().cpu().clone(),\n",
- " \"W\": model.W.detach().cpu().clone()\n",
- " }\n",
- " param_history.append(current_params)\n",
- "\n",
- " # Print verbose information every `verbose_interval` epochs\n",
- " if (epoch + 1) % verbose_interval == 0:\n",
- " print(f\"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%\")\n",
- "\n",
- " return loss_history, accuracy_history, param_history # Return loss history for plotting"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0e86c4f6-83a6-4465-abf0-7d104432cc9c",
- "metadata": {},
- "source": [
- "## Plotting functions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "014e2d10-9550-4fd4-adb7-168a27fda1b3",
- "metadata": {},
- "outputs": [],
- "source": [
- "def style_axes(ax, numyticks=5, numxticks=5, labelsize=24):\n",
- " # Y-axis ticks\n",
- " ax.tick_params(axis=\"y\", which=\"both\", bottom=True, top=False,\n",
- " labelbottom=True, left=True, right=False,\n",
- " labelleft=True, direction='out', length=7, width=1.5, pad=8, labelsize=labelsize)\n",
- " ax.yaxis.set_major_locator(MaxNLocator(nbins=numyticks))\n",
- " \n",
- " # X-axis ticks\n",
- " ax.tick_params(axis=\"x\", which=\"both\", bottom=True, top=False,\n",
- " labelbottom=True, left=True, right=False,\n",
- " labelleft=True, direction='out', length=7, width=1.5, pad=8, labelsize=labelsize)\n",
- " ax.xaxis.set_major_locator(MaxNLocator(nbins=numxticks))\n",
- "\n",
- " # Scientific notation formatting\n",
- " if ax.get_yscale() == 'linear':\n",
- " ax.ticklabel_format(style='sci', axis='y', scilimits=(-2, 2))\n",
- " if ax.get_xscale() == 'linear':\n",
- " ax.ticklabel_format(style='sci', axis='x', scilimits=(-2, 2))\n",
- "\n",
- " ax.xaxis.offsetText.set_fontsize(20)\n",
- " ax.grid()\n",
- "\n",
- " # Customize spines\n",
- " for spine in [\"top\", \"right\"]:\n",
- " ax.spines[spine].set_visible(False)\n",
- " for spine in [\"left\", \"bottom\"]:\n",
- " ax.spines[spine].set_linewidth(3)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "20989d96-f34f-4be7-a0f9-4b92fb7f235a",
- "metadata": {},
- "outputs": [],
- "source": [
- "def get_power(points):\n",
- " p = len(points)\n",
- " num_coefficients = (p // 2) + 1\n",
- " \n",
- " # Perform FFT and calculate power spectrum\n",
- " ft = np.fft.fft(points) # Could consider using np.fft.rfft which is designed for real valued input.\n",
- " power = np.abs(ft[:num_coefficients])**2 / p\n",
- " \n",
- " # Double power for frequencies strictly between 0 and Nyquist (Nyquist is not doubled if p is even)\n",
- " if p % 2 == 0: # p is even, Nyquist frequency at index num_coefficients - 1\n",
- " power[1:num_coefficients - 1] *= 2\n",
- " else: # p is odd, no Nyquist frequency\n",
- " power[1:] *= 2\n",
- "\n",
- " # Confirm the power sum approximates the squared norm of points\n",
- " total_power = np.sum(power)\n",
- " norm_squared = np.linalg.norm(points)**2\n",
- " if not np.isclose(total_power, norm_squared, rtol=1e-3):\n",
- " print(f\"Warning: Total power {total_power:.3f} does not match norm squared {norm_squared:.3f}\")\n",
- "\n",
- " return np.arange(num_coefficients), power\n",
- "\n",
- "def interpolate(ax, points, color, continuous, alpha=1.0):\n",
- " p = len(points)\n",
- " if continuous:\n",
- " # Perform Fourier Transform\n",
- " ft = np.fft.fft(points)\n",
- " \n",
- " # Keep only non-negative frequencies (first half + Nyquist if p is even)\n",
- " num_coefficients = (p // 2) + 1\n",
- " ft = ft[:num_coefficients] # Truncate to keep non-negative frequencies\n",
- " \n",
- " # Create a dense set of x-values for smooth interpolation\n",
- " xs = np.linspace(0, p, 10 * p) # 10 times more points than the original for smoothness\n",
- " curr_val = np.zeros(xs.shape, dtype=complex)\n",
- " \n",
- " # Use only non-negative frequencies for interpolation\n",
- " for freq in range(num_coefficients):\n",
- " theta = np.angle(ft[freq])\n",
- " r = np.abs(ft[freq]) / p\n",
- " # Double amplitude except for DC (freq = 0) and Nyquist (freq = p / 2, when p is even)\n",
- " if freq > 0 and (freq < p / 2 or p % 2 != 0):\n",
- " r *= 2\n",
- " curr_val += r * np.exp(1j * ((2 * np.pi * freq * xs / p) + theta))\n",
- "\n",
- " # Plot the real part (since output is real-valued)\n",
- " ax.plot(xs, curr_val.real, color=color, alpha=alpha)\n",
- " else:\n",
- " ax.plot(np.arange(p), points, color=color, alpha=alpha) "
- ]
- },
- {
- "cell_type": "markdown",
- "id": "e99dae27-f8fe-403a-b70f-0bcaf818cbe7",
- "metadata": {},
- "source": [
- "## Gradient Descent Experiment"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "bcd15c5a-5745-41ba-b015-48e403160c7e",
- "metadata": {},
- "outputs": [],
- "source": [
- "seed = 0 # or any integer you like\n",
- "random.seed(seed)\n",
- "np.random.seed(seed)\n",
- "torch.manual_seed(seed)\n",
- "torch.cuda.manual_seed_all(seed) # if using GPU\n",
- "\n",
- "# TEST_MODE: Reduce p and hidden_size for faster automated testing\n",
- "import os\n",
- "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n",
- "\n",
- "# Data Generation using the new function\n",
- "p = 11 # Keep same value in TEST_MODE to avoid index errors # Modulus (reduced in test mode)\n",
- "\n",
- "# Get base vector\n",
- "# template = one_hot(p)\n",
- "template = generate_fixed_template(p)\n",
- "\n",
- "# Mean center template\n",
- "template -= np.mean(template)\n",
- "\n",
- "# Generate dataset using numpy\n",
- "X, Y = ModularAdditionDataset(p, template)\n",
- "\n",
- "# Convert to PyTorch tensors\n",
- "X_tensor = torch.tensor(X, dtype=torch.float32).view(-1, 2 * p) # Flatten input (num_samples, 2*p)\n",
- "Y_tensor = torch.tensor(Y, dtype=torch.float32) # Targets (num_samples, p)\n",
- "\n",
- "# Create a TensorDataset and DataLoader\n",
- "dataset = TensorDataset(X_tensor, Y_tensor)\n",
- "dataloader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)\n",
- "# dataloader = DataLoader(dataset, batch_size=32, shuffle=False)\n",
- "\n",
- "# Initialize model\n",
- "hidden_size = 20 if TEST_MODE else 200 # Reduced in test mode\n",
- "model = TwoLayerNet(p=p, hidden_size=hidden_size, nonlinearity='square', init_scale=1e-5, output_scale=1e0)\n",
- "\n",
- "# Create loss function\n",
- "loss = nn.MSELoss()\n",
- "\n",
- "# Construct optimizer\n",
- "lr = 0.01\n",
- "optimizer = PerNeuronScaledSGD(model, lr=lr, degree=3) # explicit\n",
- "\n",
- "# Train the model\n",
- "# TEST_MODE: Set to reduce epochs for automated testing\n",
- "import os\n",
- "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n",
- "epochs = 2 if TEST_MODE else 50001\n",
- "loss_history, accuracy_history, param_history = train(model, dataloader, loss, optimizer, epochs=epochs, verbose_interval=max(1, epochs//10))"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0f48aebc-a439-405a-a057-3f5c24cca91a",
- "metadata": {},
- "source": [
- "## Plot Loss"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "ff46febe-abb5-459a-bb06-a18a26afb967",
- "metadata": {},
- "outputs": [],
- "source": [
- "fig, ax = plt.subplots(1, 1, figsize=(8, 6))\n",
- "ax.plot(list(loss_history), lw=7)\n",
- "\n",
- "# === Compute power spectrum of template ===\n",
- "freq, power = get_power(template)\n",
- "valid = power > 1e-20\n",
- "freq, power = freq[valid], power[valid]\n",
- "sorted_idx = np.argsort(-power)\n",
- "freq, power = freq[sorted_idx], power[sorted_idx]\n",
- "\n",
- "alpha_values = [np.sum(power[k:]) for k in range(len(power))]\n",
- "coef = 1 / p\n",
- "for k, alpha in enumerate(alpha_values):\n",
- " ax.axhline(y=coef * alpha, color='black', linestyle='--', linewidth=2, zorder=-2)\n",
- "\n",
- "ax.set_xscale(\"log\")\n",
- "ax.set_yscale(\"log\")\n",
- "ax.set_ylim(1e-2, 10)\n",
- "ax.set_xlabel('Epochs', fontsize=24)\n",
- "ax.set_ylabel('Train Loss', fontsize=24)\n",
- "\n",
- "style_axes(ax)\n",
- "plt.grid(False)\n",
- "plt.tight_layout()\n",
- "plt.savefig(\"loss-without-lines.pdf\", bbox_inches=\"tight\")\n",
- "plt.savefig(\"loss-without-lines.svg\", bbox_inches=\"tight\")\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "40b851e7-6256-43cd-b9f3-aca38db04917",
- "metadata": {},
- "source": [
- "## Power Spectrum of output"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "68b25ca9-6339-49dd-9d45-577a51798a25",
- "metadata": {},
- "outputs": [],
- "source": [
- "# === SETTINGS ===\n",
- "p = Y_tensor.shape[1]\n",
- "num_freqs = p // 2 + 1\n",
- "mom = 0.9\n",
- "\n",
- "# Compute template power spectrum\n",
- "template_ft = np.fft.rfft(template)\n",
- "template_power = np.abs(template_ft)[:num_freqs]\n",
- "\n",
- "# === Compute power spectrum of template ===\n",
- "freq, power = get_power(template)\n",
- "valid = power > 1e-20\n",
- "freq, power = freq[valid], power[valid]\n",
- "sorted_idx = np.argsort(-power)\n",
- "freq, power = freq[sorted_idx], power[sorted_idx]\n",
- "\n",
- "# === Theory lines ===\n",
- "alpha_values = [np.sum(power[k:]) for k in range(len(power))]\n",
- "coef = 1 / p\n",
- "theta0 = np.sqrt(2) * model.init_scale\n",
- "uMax = [np.sqrt(2 * p / 27) * (p * power[k] / 2)**(3/2) / p**2 for k in range(len(power))]\n",
- "tau_values = [(1 / theta0 - 1) / (3 * uMax[k]) for k in range(len(uMax))]\n",
- "step_size = 2 * coef * lr / (1 - mom)\n",
- "\n",
- "\n",
- "# Color settings\n",
- "cmap = plt.colormaps.get_cmap('tab20').resampled(num_freqs)\n",
- "manual_colors = {\n",
- " 0: 'tab:blue',\n",
- " 1: 'tab:orange',\n",
- " 2: 'tab:red',\n",
- " 3: 'tab:green',\n",
- " 4: 'tab:brown',\n",
- " 5: 'tab:purple',\n",
- "}\n",
- "colors = [manual_colors.get(i, cmap(i)) for i in range(num_freqs)]\n",
- "\n",
- "# Compute output power over time (GD)\n",
- "num_points = 1000\n",
- "steps = np.unique(np.logspace(0, np.log10(len(param_history) - 1), num_points, dtype=int))\n",
- "powers_over_time = []\n",
- "\n",
- "for step in steps:\n",
- " model.load_state_dict(param_history[step])\n",
- " model.eval()\n",
- " with torch.no_grad():\n",
- " outputs = model(X_tensor)\n",
- " ft = np.fft.rfft(outputs.detach().cpu().numpy(), axis=1)\n",
- " avg_power = np.mean(np.abs(ft), axis=0)\n",
- " powers_over_time.append(avg_power)\n",
- "\n",
- "powers_over_time = np.array(powers_over_time) # shape: (steps, freqs)\n",
- "\n",
- "# === PLOTTING ===\n",
- "fig, ax = plt.subplots(figsize=(8, 6))\n",
- "\n",
- "for k in range(num_freqs):\n",
- " color = colors[k]\n",
- " label = fr\"$\\xi = {k}$\"# if k in [1, 3, 5] else None\n",
- " ax.plot(steps, powers_over_time[:, k], color=color, lw=5, label=label)\n",
- " label_agf = 'AGF' if k == 10 else None\n",
- " ax.axhline(template_power[k], color=color, linestyle='dotted', linewidth=2, alpha=0.5, zorder=-10)\n",
- "\n",
- "\n",
- "# Labeling and formatting\n",
- "ax.set_xscale('log')\n",
- "ax.set_ylabel(\"Power\", fontsize=24)\n",
- "ax.set_xlabel(\"Epochs\", fontsize=24)\n",
- "ax.legend(fontsize=14, title=\"Frequency\", title_fontsize=16, loc='upper left', labelspacing=0.25)\n",
- "style_axes(ax)\n",
- "ax.grid(False)\n",
- "plt.tight_layout()\n",
- "plt.savefig(\"fourier_power_only.pdf\", bbox_inches=\"tight\")\n",
- "plt.savefig(\"fourier_power_only.svg\", bbox_inches=\"tight\")\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5ef2c971-d9f1-41e6-b8eb-4e467496ccfd",
- "metadata": {},
- "source": [
- "## Plot outputs"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e333d1ab-1501-434f-86d2-82c10bb58f11",
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "import torch\n",
- "\n",
- "# -------------------------------------------------------------------\n",
- "# Choose the template (example: first row of Y_tensor)\n",
- "# -------------------------------------------------------------------\n",
- "template = Y_tensor[0].detach().cpu().numpy() # shape (p,)\n",
- "p = template.shape[0]\n",
- "x = np.arange(p)\n",
- "\n",
- "# -------------------------------------------------------------------\n",
- "# Figure 1: Template as black bar plot\n",
- "# -------------------------------------------------------------------\n",
- "fig1, ax1 = plt.subplots(figsize=(8, 6))\n",
- "ax1.bar(x, template, color=\"black\")\n",
- "\n",
- "ax1.set_xlabel(\"Index\", fontsize=14)\n",
- "ax1.set_ylabel(\"Template value\", fontsize=14)\n",
- "style_axes(ax1)\n",
- "ax1.grid(False)\n",
- "\n",
- "plt.tight_layout()\n",
- "fig1.savefig(\"template_bar.pdf\", bbox_inches=\"tight\")\n",
- "\n",
- "# -------------------------------------------------------------------\n",
- "# Figure 2: Fourier magnitude bar plot with conjugate-pair coloring\n",
- "# -------------------------------------------------------------------\n",
- "# Compute Fourier transform and magnitude\n",
- "fft_template = np.fft.fft(template)\n",
- "fft_mag = np.abs(fft_template)\n",
- "freqs = np.arange(p)\n",
- "\n",
- "# Number of *frequency groups* accounting for conjugate symmetry:\n",
- "# groups: 0, 1, ..., floor(p/2)\n",
- "num_groups = p // 2 + 1\n",
- "\n",
- "# Color settings\n",
- "cmap = plt.colormaps.get_cmap('tab20').resampled(num_groups)\n",
- "manual_colors = {\n",
- " 0: 'tab:blue',\n",
- " 1: 'tab:orange',\n",
- " 2: 'tab:red',\n",
- " 3: 'tab:green',\n",
- " 4: 'tab:brown',\n",
- " 5: 'tab:purple',\n",
- "}\n",
- "group_colors = [manual_colors.get(i, cmap(i)) for i in range(num_groups)]\n",
- "\n",
- "# Assign each k a color based on its conjugate-symmetry group\n",
- "bar_colors = []\n",
- "for k in range(p):\n",
- " # group index: k and p-k share the same group\n",
- " g = k if k <= p // 2 else p - k\n",
- " bar_colors.append(group_colors[g])\n",
- "\n",
- "fig2, ax2 = plt.subplots(figsize=(8, 6))\n",
- "ax2.bar(freqs, fft_mag, color=bar_colors)\n",
- "\n",
- "ax2.set_xlabel(\"Frequency index $k$\", fontsize=14)\n",
- "ax2.set_ylabel(r\"$|\\hat{t}[k]|$\", fontsize=14)\n",
- "style_axes(ax2)\n",
- "ax2.grid(False)\n",
- "\n",
- "plt.tight_layout()\n",
- "fig2.savefig(\"template_fft_bar.pdf\", bbox_inches=\"tight\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4ecb1adb-66ed-44a1-9b6a-ef5dcc6dbe11",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "group-agf",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.12"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/notebooks/cn.ipynb b/notebooks/cn.ipynb
new file mode 100644
index 0000000..fa3bf65
--- /dev/null
+++ b/notebooks/cn.ipynb
@@ -0,0 +1,538 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Binary Group Composition on $C_n$\n",
+ "\n",
+ "**Group:** Cyclic group $C_p$ of order $p$ (i.e., modular addition mod $p$). \n",
+ "**Task:** Given encodings of two group elements $a, b \in C_p$, predict the encoding of their sum $a + b \pmod{p}$. \n",
+ "**Sequence length:** $k = 2$ (binary composition). \n",
+ "**Architecture:** `TwoLayerNet` with square nonlinearity. \n",
+ "**Key result:** The network learns one Fourier mode at a time, producing a staircase in the training loss."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import random\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.utils.data import DataLoader, TensorDataset\n",
+ "\n",
+ "import src.dataset as dataset\n",
+ "import src.model as model\n",
+ "import src.optimizer as optimizer\n",
+ "import src.power as power\n",
+ "import src.template as template\n",
+ "import src.train as train_mod\n",
+ "import src.viz as viz"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Configuration"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n",
+ "\n",
+ "seed = 0\n",
+ "random.seed(seed)\n",
+ "np.random.seed(seed)\n",
+ "torch.manual_seed(seed)\n",
+ "\n",
+ "p = 11\n",
+ "hidden_size = 20 if TEST_MODE else 200\n",
+ "epochs = 2 if TEST_MODE else 5000\n",
+ "lr = 0.01\n",
+ "init_scale = 1e-5\n",
+ "\n",
+ "FIGURES_DIR = \"figures\"\n",
+ "os.makedirs(FIGURES_DIR, exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Template and Dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "067e5bc3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Build a fixed Cn template with known Fourier structure\n",
+ "tpl = template.fixed_cn(\n",
+ " group_size=p,\n",
+ " fourier_coef_mags=[0, 12.5, 10, 7.5, 5, 2.5],\n",
+ ")\n",
+ "\n",
+ "# Mean-center the template\n",
+ "tpl = tpl - np.mean(tpl)\n",
+ "\n",
+ "# Build exhaustive dataset: all p^2 pairs\n",
+ "X, Y = dataset.cn_dataset(tpl)\n",
+ "\n",
+ "# Move to tensors and flatten\n",
+ "X_tensor, Y_tensor, device = dataset.move_dataset_to_device_and_flatten(X, Y)\n",
+ "\n",
+ "ds = TensorDataset(X_tensor, Y_tensor)\n",
+ "dataloader = DataLoader(ds, batch_size=len(ds), shuffle=False)\n",
+ "\n",
+ "print(f\"Group: C_{p}, order {p}\")\n",
+ "print(f\"Dataset: {len(ds)} samples (all {p}x{p} pairs)\")\n",
+ "print(f\"X shape: {X_tensor.shape}, Y shape: {Y_tensor.shape}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a26b21e6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Visualize template and its power spectrum\n",
+ "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
+ "\n",
+ "ax1.bar(range(p), tpl, color=\"black\")\n",
+ "ax1.set_xlabel(\"Group element\")\n",
+ "ax1.set_ylabel(\"Template value\")\n",
+ "ax1.set_title(f\"Template $t$ on $C_{{{p}}}$\")\n",
+ "\n",
+ "pwr, freqs = power.get_power_1d(tpl)\n",
+ "ax2.bar(freqs, pwr, color=\"steelblue\")\n",
+ "ax2.set_xlabel(\"Frequency\")\n",
+ "ax2.set_ylabel(\"Power\")\n",
+ "ax2.set_title(\"Power spectrum of template\")\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.savefig(f\"{FIGURES_DIR}/cn_template.pdf\", bbox_inches=\"tight\")\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Model and Optimizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "230b324e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "net = model.TwoLayerNet(\n",
+ " group_size=p,\n",
+ " hidden_size=hidden_size,\n",
+ " nonlinearity=\"square\",\n",
+ " init_scale=init_scale,\n",
+ ")\n",
+ "net = net.to(device)\n",
+ "\n",
+ "criterion = nn.MSELoss()\n",
+ "opt = optimizer.PerNeuronScaledSGD(net, lr=lr, degree=3)\n",
+ "\n",
+ "print(f\"Model: TwoLayerNet(p={p}, hidden={hidden_size}, init_scale={init_scale})\")\n",
+ "print(f\"Optimizer: PerNeuronScaledSGD(lr={lr}, degree=3)\")\n",
+ "print(f\"Training for {epochs} epochs\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "loss_history, val_loss_history, param_history, param_save_epochs, final_epoch = train_mod.train(\n",
+ " net,\n",
+ " dataloader,\n",
+ " criterion,\n",
+ " opt,\n",
+ " epochs=epochs,\n",
+ " verbose_interval=max(1, epochs // 10),\n",
+ " save_param_interval=1,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Training Loss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Compute theoretical loss plateau levels\n",
+ "theory = power.theoretical_loss_levels_1d(tpl)\n",
+ "\n",
+ "fig, ax = plt.subplots(figsize=(8, 6))\n",
+ "ax.plot(loss_history, lw=4)\n",
+ "\n",
+ "for level in theory[\"levels\"]:\n",
+ " ax.axhline(y=level, color=\"black\", linestyle=\"--\", linewidth=2, zorder=-2)\n",
+ "\n",
+ "ax.set_xscale(\"log\")\n",
+ "ax.set_yscale(\"log\")\n",
+ "ax.set_xlabel(\"Epochs\", fontsize=18)\n",
+ "ax.set_ylabel(\"Train Loss\", fontsize=18)\n",
+ "ax.set_title(f\"Training loss on $C_{{{p}}}$\", fontsize=20)\n",
+ "viz.style_axes(ax)\n",
+ "ax.grid(False)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.savefig(f\"{FIGURES_DIR}/cn_loss.pdf\", bbox_inches=\"tight\")\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Power Spectrum Over Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Compute template power for reference lines\n",
+ "num_freqs = p // 2 + 1\n",
+ "template_ft = np.fft.rfft(tpl)\n",
+ "template_power = np.abs(template_ft)[:num_freqs]\n",
+ "\n",
+ "# Compute output power over time\n",
+ "num_points = min(500, len(param_history))\n",
+ "steps = np.unique(np.logspace(0, np.log10(max(1, len(param_history) - 1)), num_points, dtype=int))\n",
+ "powers_over_time = []\n",
+ "\n",
+ "for step in steps:\n",
+ " net.load_state_dict(param_history[step])\n",
+ " net.eval()\n",
+ " with torch.no_grad():\n",
+ " outputs = net(X_tensor)\n",
+ " ft = np.fft.rfft(outputs.detach().cpu().numpy(), axis=1)\n",
+ " avg_power = np.mean(np.abs(ft), axis=0)\n",
+ " powers_over_time.append(avg_power)\n",
+ "\n",
+ "powers_over_time = np.array(powers_over_time)\n",
+ "\n",
+ "# Plot\n",
+ "colors = [\"tab:blue\", \"tab:orange\", \"tab:red\", \"tab:green\", \"tab:brown\", \"tab:purple\"]\n",
+ "fig, ax = plt.subplots(figsize=(8, 6))\n",
+ "\n",
+ "for k in range(num_freqs):\n",
+ " color = colors[k] if k < len(colors) else f\"C{k}\"\n",
+ " ax.plot(steps, powers_over_time[:, k], color=color, lw=4, label=rf\"$\\xi = {k}$\")\n",
+ " ax.axhline(template_power[k], color=color, linestyle=\"dotted\", linewidth=2, alpha=0.5, zorder=-10)\n",
+ "\n",
+ "ax.set_xscale(\"log\")\n",
+ "ax.set_ylabel(\"Power\", fontsize=18)\n",
+ "ax.set_xlabel(\"Epochs\", fontsize=18)\n",
+ "ax.set_title(f\"Power spectrum over training on $C_{{{p}}}$\", fontsize=20)\n",
+ "ax.legend(fontsize=12, title=\"Frequency\", title_fontsize=14, loc=\"upper left\", labelspacing=0.25)\n",
+ "viz.style_axes(ax)\n",
+ "ax.grid(False)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.savefig(f\"{FIGURES_DIR}/cn_power_spectrum.pdf\", bbox_inches=\"tight\")\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## AGF Numerics\n",
+ "\n",
+ "Compare gradient descent training with the Alternating Gradient Flow (AGF) approximation.\n",
+ "AGF decomposes training into alternating utility-maximization and cost-minimization phases,\n",
+ "predicting when each Fourier mode activates."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from tqdm import tqdm\n",
+ "\n",
+ "\n",
+ "class ModsumSubNetwork(nn.Module):\n",
+ " \"\"\"A single neuron of the two-layer network for AGF analysis.\"\"\"\n",
+ "\n",
+ " def __init__(self, d_in, d_out, init_scale):\n",
+ " super().__init__()\n",
+ " assert d_in % 2 == 0\n",
+ " self.p = d_in // 2\n",
+ " self.u = nn.Linear(self.p, 1, bias=False)\n",
+ " self.v = nn.Linear(self.p, 1, bias=False)\n",
+ " self.w = nn.Linear(1, d_out, bias=False)\n",
+ " with torch.no_grad():\n",
+ " self.w.weight.mul_(init_scale)\n",
+ " self.u.weight.mul_(init_scale)\n",
+ " self.v.weight.mul_(init_scale)\n",
+ " self.active = False\n",
+ " self.util_acc = 0\n",
+ " self.c_a = 1 / self.get_norm() - 1\n",
+ " self.normalize()\n",
+ "\n",
+ " def get_norm(self):\n",
+ " sqnorm = lambda x: torch.linalg.norm(x.weight) ** 2\n",
+ " return torch.sqrt(sqnorm(self.w) + sqnorm(self.u) + sqnorm(self.v))\n",
+ "\n",
+ " def reinitialize(self, u, v, w):\n",
+ " with torch.no_grad():\n",
+ " self.u.weight.copy_(u)\n",
+ " self.v.weight.copy_(v)\n",
+ " self.w.weight.copy_(w)\n",
+ " self.c_a = 1 / self.get_norm() - 1\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x1 = x[:, : self.p]\n",
+ " x2 = x[:, self.p :]\n",
+ " return self.w((self.u(x1) + self.v(x2)) ** 2)\n",
+ "\n",
+ " def normalize(self):\n",
+ " norm = self.get_norm()\n",
+ " with torch.no_grad():\n",
+ " self.w.weight.div_(norm)\n",
+ " self.u.weight.div_(norm)\n",
+ " self.v.weight.div_(norm)\n",
+ "\n",
+ " def utility_step(self, x, residual, learning_rate):\n",
+ " f_i = self(x)\n",
+ " util = torch.einsum(\"nd,nd->n\", f_i, residual).mean()\n",
+ " self.util_acc += 3 * learning_rate * util.item()\n",
+ " norm_th = 1 / (1 + self.c_a - self.util_acc)\n",
+ " util.backward()\n",
+ " with torch.no_grad():\n",
+ " self.w.weight += norm_th * learning_rate * self.w.weight.grad\n",
+ " self.u.weight += norm_th * learning_rate * self.u.weight.grad\n",
+ " self.v.weight += norm_th * learning_rate * self.v.weight.grad\n",
+ " self.w.weight.grad.zero_()\n",
+ " self.u.weight.grad.zero_()\n",
+ " self.v.weight.grad.zero_()\n",
+ " self.normalize()\n",
+ "\n",
+ "\n",
+ "class ModsumNetwork(nn.Module):\n",
+ " \"\"\"Network of ModsumSubNetwork neurons for AGF simulation.\"\"\"\n",
+ "\n",
+ " def __init__(self, d_in, d_out, init_scale, width=100):\n",
+ " super().__init__()\n",
+ " self.d_in = d_in\n",
+ " self.d_out = d_out\n",
+ " self.width = width\n",
+ " neurons = [ModsumSubNetwork(d_in, d_out, init_scale) for _ in range(width)]\n",
+ " self.neurons = nn.ModuleList(neurons)\n",
+ " self.set_mode(\"utilmax\")\n",
+ "\n",
+ " def load_init(self, U, V, W):\n",
+ " for i, n in enumerate(self.neurons):\n",
+ " u, v, w = U[i], V[i], W[i][:, None]\n",
+ " n.reinitialize(u, v, w)\n",
+ "\n",
+ " def dormant(self):\n",
+ " return [neuron for neuron in self.neurons if not neuron.active]\n",
+ "\n",
+ " def set_mode(self, mode):\n",
+ " if mode not in [\"utilmax\", \"costmin\"]:\n",
+ " raise ValueError(\"mode must be utilmax or costmin\")\n",
+ " self.mode = mode\n",
+ " for neuron in self.neurons:\n",
+ " grad_on = (mode == \"utilmax\") ^ neuron.active\n",
+ " for param in neuron.parameters():\n",
+ " param.requires_grad = grad_on\n",
+ "\n",
+ " def forward(self, x):\n",
+ " active = [n for n in self.neurons if n.active]\n",
+ " if not active:\n",
+ " return torch.zeros(x.shape[0], self.d_out)\n",
+ " outputs = torch.stack([n(x) for n in active], dim=0)\n",
+ " return torch.sum(outputs, dim=0)\n",
+ "\n",
+ "\n",
+ "def train_agf(\n",
+ " X_train, Y_train, init_sz=1e-3, agf_steps=5, from_init=None,\n",
+ " utilmax_lr=1, costmin_lr=1, costmin_maxiter=1e4, loss_thresh=1e-4,\n",
+ "):\n",
+ " \"\"\"Run the Alternating Gradient Flow (AGF) approximation.\"\"\"\n",
+ " d_in, d_out = X_train.shape[-1], Y_train.shape[-1]\n",
+ " if from_init:\n",
+ " U, V, W = from_init[\"U\"], from_init[\"V\"], from_init[\"W\"]\n",
+ " width = U.shape[0]\n",
+ " agf_net = ModsumNetwork(d_in, d_out, init_sz, width=width)\n",
+ " agf_net.load_init(U, V, W)\n",
+ " else:\n",
+ " agf_net = ModsumNetwork(d_in, d_out, init_sz, width=agf_steps)\n",
+ " X_train.requires_grad = False\n",
+ "\n",
+ " results = {\"t\": [], \"losses\": [], \"pred\": []}\n",
+ "\n",
+ " def update_results(t):\n",
+ " results[\"t\"].append(t)\n",
+ " residual = (Y_train - agf_net(X_train)).detach()\n",
+ " results[\"losses\"].append((residual**2).mean().item())\n",
+ " results[\"pred\"].append(agf_net(X_train).detach().cpu().clone())\n",
+ "\n",
+ " t = 0\n",
+ " update_results(t)\n",
+ " for _ in tqdm(range(agf_steps)):\n",
+ " residual = (1 / d_out) * 2 * (Y_train - agf_net(X_train))\n",
+ " residual = residual.detach()\n",
+ " iters = 0\n",
+ " mode = \"utilmax\"\n",
+ " while mode == \"utilmax\":\n",
+ " for n in agf_net.neurons:\n",
+ " if n.active:\n",
+ " continue\n",
+ " n.utility_step(X_train, residual, utilmax_lr)\n",
+ " if n.util_acc > n.c_a:\n",
+ " n.active = True\n",
+ " mode = \"costmin\"\n",
+ " iters += 1\n",
+ " agf_net.set_mode(mode)\n",
+ " t += iters\n",
+ "\n",
+ " agf_opt = torch.optim.SGD(agf_net.parameters(), lr=costmin_lr, momentum=0.9)\n",
+ " for _ in range(int(costmin_maxiter)):\n",
+ " agf_opt.zero_grad(set_to_none=False)\n",
+ " residual = Y_train - agf_net(X_train)\n",
+ " loss = (residual**2).mean()\n",
+ " loss.backward()\n",
+ " agf_opt.step()\n",
+ " agf_net.set_mode(\"utilmax\")\n",
+ "\n",
+ " print(f\"loss: {loss.item():.5f}\")\n",
+ " update_results(t)\n",
+ "\n",
+ " if not agf_net.dormant() or loss.item() < loss_thresh:\n",
+ " break\n",
+ "\n",
+ " return results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if not TEST_MODE:\n",
+ " agf_results = train_agf(\n",
+ " X_tensor, Y_tensor,\n",
+ " init_sz=init_scale,\n",
+ " agf_steps=50,\n",
+ " from_init=param_history[0],\n",
+ " utilmax_lr=0.1,\n",
+ " costmin_lr=0.01,\n",
+ " costmin_maxiter=1e4,\n",
+ " loss_thresh=1e-4,\n",
+ " )\n",
+ "else:\n",
+ " agf_results = None\n",
+ " print(\"Skipping AGF in TEST_MODE\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Loss: GD vs AGF"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig, ax = plt.subplots(figsize=(8, 6))\n",
+ "ax.plot(loss_history, lw=4, label=\"GD\")\n",
+ "\n",
+ "# Theory plateau levels\n",
+ "for level in theory[\"levels\"]:\n",
+ " ax.axhline(y=level, color=\"black\", linestyle=\"--\", linewidth=2, zorder=-2)\n",
+ "\n",
+ "# AGF overlay\n",
+ "if agf_results is not None:\n",
+ " utilmax_lr_val = 0.1\n",
+ " f = utilmax_lr_val / lr\n",
+ " agf_times = agf_results[\"t\"] + [epochs]\n",
+ " agf_losses = agf_results[\"losses\"] + [agf_results[\"losses\"][-1]]\n",
+ " ax.step(f * np.array(agf_times), agf_losses, where=\"post\", lw=2, ls=\"dashed\", color=\"k\", label=\"AGF\")\n",
+ " ax.legend(fontsize=14)\n",
+ "\n",
+ "ax.set_xscale(\"log\")\n",
+ "ax.set_yscale(\"log\")\n",
+ "ax.set_xlabel(\"Epochs\", fontsize=18)\n",
+ "ax.set_ylabel(\"Train Loss\", fontsize=18)\n",
+ "ax.set_title(f\"GD vs AGF on $C_{{{p}}}$\", fontsize=20)\n",
+ "viz.style_axes(ax)\n",
+ "ax.grid(False)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.savefig(f\"{FIGURES_DIR}/cn_loss_agf.pdf\", bbox_inches=\"tight\")\n",
+ "plt.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "group-agf",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.12.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/notebooks/cnxcn.ipynb b/notebooks/cnxcn.ipynb
new file mode 100644
index 0000000..3307c24
--- /dev/null
+++ b/notebooks/cnxcn.ipynb
@@ -0,0 +1,270 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Binary Group Composition on $C_n \\times C_n$\n",
+ "\n",
+ "**Group:** Product of cyclic groups $C_n \\times C_n$ of order $n^2$. \n",
+ "**Task:** Given encodings of two group elements $g_1, g_2 \\in C_n \\times C_n$, predict the encoding of their product. \n",
+ "**Sequence length:** $k = 2$ (binary composition). \n",
+ "**Architecture:** `TwoLayerNet` with square nonlinearity. \n",
+ "**Key result:** The network learns one irreducible representation at a time."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import random\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "from torch.utils.data import DataLoader, TensorDataset\n",
+ "\n",
+ "import src.dataset as dataset\n",
+ "import src.model as model\n",
+ "import src.power as power\n",
+ "import src.template as template\n",
+ "import src.train as train_mod\n",
+ "import src.viz as viz"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Configuration"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n",
+ "\n",
+ "seed = 47\n",
+ "random.seed(seed)\n",
+ "np.random.seed(seed)\n",
+ "torch.manual_seed(seed)\n",
+ "\n",
+ "n = 3 if TEST_MODE else 5\n",
+ "group_size = n * n\n",
+ "\n",
+ "hidden_size = 32 if TEST_MODE else 128\n",
+ "epochs = 2 if TEST_MODE else 1000\n",
+ "lr = 0.001\n",
+ "init_scale = 1e-2\n",
+ "batch_size = 32 if TEST_MODE else 128\n",
+ "\n",
+ "FIGURES_DIR = \"figures\"\n",
+ "os.makedirs(FIGURES_DIR, exist_ok=True)\n",
+ "\n",
+ "print(f\"Group: C_{n} x C_{n}, order {group_size}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Template and Dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Build a fixed CnxCn template with known Fourier structure\n",
+ "fourier_coef_mags = np.random.RandomState(seed).rand(n) * 10\n",
+ "tpl = template.fixed_cnxcn(image_length=n, fourier_coef_mags=fourier_coef_mags)\n",
+ "\n",
+ "# Build exhaustive dataset: all group_size^2 pairs\n",
+ "X, Y = dataset.cnxcn_dataset(tpl)\n",
+ "X_tensor, Y_tensor, device = dataset.move_dataset_to_device_and_flatten(X, Y)\n",
+ "\n",
+ "ds = TensorDataset(X_tensor, Y_tensor)\n",
+ "dataloader = DataLoader(ds, batch_size=batch_size, shuffle=False)\n",
+ "\n",
+ "print(f\"Dataset: {len(ds)} samples\")\n",
+ "print(f\"X shape: {X_tensor.shape}, Y shape: {Y_tensor.shape}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Visualize template as 2D image\n",
+ "tpl_2d = tpl.reshape(n, n) if tpl.ndim == 1 else tpl\n",
+ "\n",
+ "fig, ax = plt.subplots(figsize=(5, 4))\n",
+ "im = ax.imshow(tpl_2d, cmap=\"RdBu_r\")\n",
+ "ax.set_xlabel(\"$C_n$ index\")\n",
+ "ax.set_ylabel(\"$C_n$ index\")\n",
+ "ax.set_title(f\"Template on $C_{{{n}}} \\\\times C_{{{n}}}$\")\n",
+ "plt.colorbar(im, ax=ax)\n",
+ "plt.tight_layout()\n",
+ "plt.savefig(f\"{FIGURES_DIR}/cnxcn_template.pdf\", bbox_inches=\"tight\")\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Model and Optimizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "net = model.TwoLayerNet(\n",
+ " group_size=group_size,\n",
+ " hidden_size=hidden_size,\n",
+ " nonlinearity=\"square\",\n",
+ " init_scale=init_scale,\n",
+ ")\n",
+ "net = net.to(device)\n",
+ "\n",
+ "criterion = nn.MSELoss()\n",
+ "opt = optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999))\n",
+ "\n",
+ "print(f\"Model: TwoLayerNet(group_size={group_size}, hidden={hidden_size})\")\n",
+ "print(f\"Optimizer: Adam(lr={lr})\")\n",
+ "print(f\"Training for {epochs} epochs\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "loss_history, val_loss_history, param_history, param_save_epochs, final_epoch = train_mod.train(\n",
+ " net,\n",
+ " dataloader,\n",
+ " criterion,\n",
+ " opt,\n",
+ " epochs=epochs,\n",
+ " verbose_interval=max(1, epochs // 10),\n",
+ " save_param_interval=max(1, epochs // 100),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Training Loss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Compute theoretical loss plateau levels\n",
+ "theory = power.theoretical_loss_levels_2d(tpl_2d)\n",
+ "\n",
+ "fig, ax = plt.subplots(figsize=(8, 6))\n",
+ "ax.plot(loss_history, lw=4)\n",
+ "\n",
+ "for level in theory[\"levels\"]:\n",
+ " ax.axhline(y=level, color=\"black\", linestyle=\"--\", linewidth=2, zorder=-2)\n",
+ "\n",
+ "ax.set_xscale(\"log\")\n",
+ "ax.set_yscale(\"log\")\n",
+ "ax.set_xlabel(\"Epochs\", fontsize=18)\n",
+ "ax.set_ylabel(\"Train Loss\", fontsize=18)\n",
+ "ax.set_title(f\"Training loss on $C_{{{n}}} \\\\times C_{{{n}}}$\", fontsize=20)\n",
+ "viz.style_axes(ax)\n",
+ "ax.grid(False)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.savefig(f\"{FIGURES_DIR}/cnxcn_loss.pdf\", bbox_inches=\"tight\")\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Predictions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Show a few predictions vs ground truth\n",
+ "net.load_state_dict(param_history[-1])\n",
+ "net.eval()\n",
+ "\n",
+ "n_examples = 3\n",
+ "indices = np.random.choice(len(Y_tensor), size=n_examples, replace=False)\n",
+ "\n",
+ "fig, axes = plt.subplots(n_examples, 2, figsize=(8, 3 * n_examples))\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " preds = net(X_tensor[indices]).detach().cpu().numpy()\n",
+ " truths = Y_tensor[indices].detach().cpu().numpy()\n",
+ "\n",
+ "for i in range(n_examples):\n",
+ " axes[i, 0].imshow(truths[i].reshape(n, n), cmap=\"RdBu_r\")\n",
+ " axes[i, 0].set_title(\"Ground truth\")\n",
+ " axes[i, 1].imshow(preds[i].reshape(n, n), cmap=\"RdBu_r\")\n",
+ " axes[i, 1].set_title(\"Prediction\")\n",
+ "\n",
+ "plt.suptitle(f\"Predictions on $C_{{{n}}} \\\\times C_{{{n}}}$ (epoch {final_epoch})\", fontsize=16)\n",
+ "plt.tight_layout()\n",
+ "plt.savefig(f\"{FIGURES_DIR}/cnxcn_predictions.pdf\", bbox_inches=\"tight\")\n",
+ "plt.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "group-agf",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.12.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/notebooks/dihedral.ipynb b/notebooks/dihedral.ipynb
deleted file mode 100644
index a7bb7ca..0000000
--- a/notebooks/dihedral.ipynb
+++ /dev/null
@@ -1,1064 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "51d11caf-0971-4324-b63b-819b714a9c3c",
- "metadata": {},
- "source": [
- "# Diehdral 1D"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "80f249f1-6985-4c73-86cd-04e1adac3e8d",
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import random\n",
- "import torch\n",
- "import torch.nn as nn\n",
- "from tqdm import tqdm\n",
- "import torch.optim as optim\n",
- "from torch.utils.data import DataLoader, TensorDataset\n",
- "import seaborn as sns\n",
- "import matplotlib.pyplot as plt\n",
- "import matplotlib.cm as cm\n",
- "from matplotlib.animation import FuncAnimation\n",
- "from matplotlib.ticker import FormatStrFormatter\n",
- "from matplotlib.ticker import FuncFormatter\n",
- "from matplotlib.ticker import MaxNLocator"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "9fd05577-db56-4d0a-bb93-1d0b48cecaf6",
- "metadata": {},
- "source": [
- "## Dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "f19bd1ad-9e8f-4720-b317-afe13fafae88",
- "metadata": {},
- "outputs": [],
- "source": [
def one_hot(p, index=0):
    """Return the standard basis vector e_index in R^p.

    The previous docstring claimed this encodes "an integer value", but the
    function always wrote position 0; the new `index` parameter makes that
    possible while defaulting to the original behavior.

    Parameters
    ----------
    p : int
        Dimension of the output vector.
    index : int, optional
        Position of the single 1 entry (default 0).

    Returns
    -------
    np.ndarray
        Length-p float array with a 1 at `index` and 0 elsewhere.
    """
    vec = np.zeros(p)
    vec[index] = 1
    return vec
- "\n",
def generate_template(p, magnitude, exponent):
    """Build a length-p template whose cosine amplitudes decay as a power law.

    Frequency k (k = 1..p-1) enters with amplitude magnitude * k**(-exponent);
    a constant DC term of 1 is always present. The sum is divided by p.
    """
    positions = np.arange(p)
    amplitudes = magnitude * np.power(np.arange(1, p), -exponent)
    template = np.ones(p)
    for k, amp in enumerate(amplitudes, start=1):
        template = template + amp * np.cos(2 * np.pi * k * positions / p)
    return template / p
- "\n",
def generate_fixed_template(p):
    """Construct a real template on Z_p from a hand-picked Fourier spectrum.

    Frequencies 1, 3 and 5 carry amplitudes 10, 5 and 2.5 respectively; each
    is mirrored onto its negative frequency so the inverse FFT comes out real.
    Requires p > 5 so the chosen frequency indices exist.
    """
    amplitudes = {1: 10, 3: 5, 5: 2.5}
    spectrum = np.zeros(p, dtype=complex)
    for freq, amp in amplitudes.items():
        spectrum[freq] = amp    # positive frequency
        spectrum[-freq] = amp   # conjugate partner (Hermitian symmetry)

    # Real part of the inverse transform; imaginary part is zero by symmetry.
    return np.fft.ifft(spectrum).real
- "\n",
def DihedralDataset(p, template):
    """Enumerate the full binary-composition dataset for the dihedral group.

    A group element is a (shift, sign) pair acting on `template` by circular
    shift followed by an optional reversal. Inputs are all 4*p*p ordered
    pairs of elements; the target is the second element applied on top of
    the first element's encoding.

    Returns
    -------
    X : np.ndarray of shape (4*p*p, 2, p) — encoded element pairs.
    Y : np.ndarray of shape (4*p*p, p) — encoded compositions.
    """
    n_samples = 4 * p * p
    X = np.zeros((n_samples, 2, p))
    Y = np.zeros((n_samples, p))

    def encode(shift, sign):
        # Roll, then reverse iff sign == -1 ([::1] is the identity slice).
        return np.roll(template, shift)[::sign]

    idx = 0
    for a in range(p):
        for b in range(p):
            for sa in (-1, 1):
                for sb in (-1, 1):
                    X[idx, 0, :] = encode(a, sa)
                    X[idx, 1, :] = encode(b, sb)
                    # Compose: apply (b, sb) to the encoded first element.
                    Y[idx, :] = np.roll(encode(a, sa), b)[::sb]
                    idx += 1

    return X, Y
- ]
- },
- {
- "cell_type": "markdown",
- "id": "7a0ecbbd-ceaf-4bef-af4a-13a22fa70063",
- "metadata": {},
- "source": [
- "## Architecture"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "2cf22b7d-49e7-445b-8742-2e75cd1fa55a",
- "metadata": {},
- "outputs": [],
- "source": [
class TwoLayerNet(nn.Module):
    """Two-layer network f(x) = output_scale * phi(x1 @ U.T + x2 @ V.T) @ W.

    The input is the concatenation of two length-p encodings; each half gets
    its own first-layer matrix, the summed preactivations pass through the
    chosen nonlinearity, and W maps the hidden layer back to R^p.

    Fix: the original printed shape/debug messages in both __init__ and on
    EVERY forward call, flooding output and slowing training; all debug
    prints are removed. Numerics and interface are unchanged.
    """

    def __init__(self, p, hidden_size, nonlinearity='square', init_scale=1.0, output_scale=1.0):
        super(TwoLayerNet, self).__init__()

        # Store dimensions and configuration.
        self.p = p
        self.hidden_size = hidden_size
        self.nonlinearity = nonlinearity
        self.init_scale = init_scale
        self.output_scale = output_scale

        # First-layer weights, one matrix per input half; 1/sqrt(2p) keeps the
        # variance of the summed preactivation comparable to a 1/sqrt(p) init.
        self.U = nn.Parameter(self.init_scale * torch.randn(hidden_size, p) / np.sqrt(2 * p))
        self.V = nn.Parameter(self.init_scale * torch.randn(hidden_size, p) / np.sqrt(2 * p))
        # Second-layer weights.
        self.W = nn.Parameter(self.init_scale * torch.randn(hidden_size, p) / np.sqrt(p))

    def forward(self, x):
        """Map a batch of flattened input pairs (batch, 2p) to (batch, p)."""
        # Split the flattened input and apply the two first-layer maps.
        x1 = x[:, :self.p] @ self.U.T
        x2 = x[:, self.p:] @ self.V.T
        x_combined = x1 + x2

        # Nonlinear activation.
        if self.nonlinearity == 'relu':
            x_combined = torch.relu(x_combined)
        elif self.nonlinearity == 'square':
            x_combined = x_combined ** 2
        elif self.nonlinearity == 'linear':
            pass  # identity
        elif self.nonlinearity == 'tanh':
            x_combined = torch.tanh(x_combined)
        elif self.nonlinearity == 'gelu':
            x_combined = torch.nn.functional.gelu(x_combined)
        else:
            raise ValueError(f"Invalid nonlinearity '{self.nonlinearity}' provided.")

        # Second (linear) layer plus the feature-learning output scaling.
        return self.output_scale * (x_combined @ self.W)
- ]
- },
- {
- "cell_type": "markdown",
- "id": "f7e7336b-5c6e-48af-a357-2b2c877f6168",
- "metadata": {},
- "source": [
- "## Optimization"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "1035f81c-e877-4655-8640-4e4c3d323af8",
- "metadata": {},
- "outputs": [],
- "source": [
- "def test_accuracy(model, dataloader):\n",
- " correct = 0\n",
- " total = 0\n",
- " print(\"Starting test_accuracy evaluation...\")\n",
- " \n",
- " with torch.no_grad(): # Disable gradient calculation for evaluation\n",
- " for i, (inputs, labels) in enumerate(dataloader):\n",
- " inputs = inputs.view(inputs.shape[0], -1) # Flatten input for FC layers\n",
- " print(f\"Batch {i+1}: inputs reshaped\")\n",
- " outputs = model(inputs)\n",
- " print(f\"Batch {i+1}: model forward pass done\")\n",
- " _, predicted = torch.max(outputs, 1) # Get the index of the largest value (class)\n",
- " _, true_labels = torch.max(labels, 1) # Get the true class from the one-hot encoding\n",
- " correct += (predicted == true_labels).sum().item()\n",
- " total += labels.size(0)\n",
- " print(f\"Batch {i+1}: accuracy updated (correct={correct}, total={total})\")\n",
- " \n",
- " accuracy = 100 * correct / total\n",
- " print(f\"Final test accuracy: {accuracy:.2f}%\")\n",
- " return accuracy\n",
- "\n",
def train(model, dataloader, criterion, optimizer, epochs=100, verbose_interval=10):
    """Train `model` with `optimizer` on `dataloader` for `epochs` epochs.

    After every epoch the argmax accuracy (via test_accuracy) and a CPU
    snapshot of the U/V/W parameters are recorded, so training dynamics can
    be replayed later.

    Fix: the original printed several debug lines per batch and per epoch,
    which dominates runtime and floods notebook output; only the periodic
    `verbose_interval` report is kept. Return values are unchanged.

    Returns
    -------
    loss_history : list[float] — mean loss per epoch.
    accuracy_history : list[float] — accuracy (%) per epoch.
    param_history : list[dict] — {"U", "V", "W"} tensor snapshots per epoch.
    """
    model.train()
    loss_history = []
    accuracy_history = []
    param_history = []

    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs = inputs.view(inputs.shape[0], -1)  # flatten for FC layers

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(dataloader)
        loss_history.append(avg_loss)

        # Evaluate on the same loader, then restore training mode.
        model.eval()
        accuracy = test_accuracy(model, dataloader)
        accuracy_history.append(accuracy)
        model.train()

        # Snapshot parameters (detached, on CPU) for later analysis.
        param_history.append({
            "U": model.U.detach().cpu().clone(),
            "V": model.V.detach().cpu().clone(),
            "W": model.W.detach().cpu().clone(),
        })

        if (epoch + 1) % verbose_interval == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")

    return loss_history, accuracy_history, param_history
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0e86c4f6-83a6-4465-abf0-7d104432cc9c",
- "metadata": {},
- "source": [
- "## Plotting functions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "014e2d10-9550-4fd4-adb7-168a27fda1b3",
- "metadata": {},
- "outputs": [],
- "source": [
def style_axes(ax, numyticks=5, numxticks=5, labelsize=24):
    """Apply the shared figure style: outward ticks, capped tick counts,
    scientific notation on linear axes, and heavy left/bottom spines."""
    shared_tick_kwargs = dict(
        which="both", bottom=True, top=False, labelbottom=True,
        left=True, right=False, labelleft=True,
        direction='out', length=7, width=1.5, pad=8, labelsize=labelsize,
    )

    # Identical tick styling on both axes; only the locator bin count differs.
    ax.tick_params(axis="y", **shared_tick_kwargs)
    ax.yaxis.set_major_locator(MaxNLocator(nbins=numyticks))
    ax.tick_params(axis="x", **shared_tick_kwargs)
    ax.xaxis.set_major_locator(MaxNLocator(nbins=numxticks))

    # Scientific notation only applies on linear scales.
    if ax.get_yscale() == 'linear':
        ax.ticklabel_format(style='sci', axis='y', scilimits=(-2, 2))
    if ax.get_xscale() == 'linear':
        ax.ticklabel_format(style='sci', axis='x', scilimits=(-2, 2))

    ax.xaxis.offsetText.set_fontsize(20)
    ax.grid()

    # Hide the top/right spines; thicken the remaining two.
    for name in ("top", "right"):
        ax.spines[name].set_visible(False)
    for name in ("left", "bottom"):
        ax.spines[name].set_linewidth(3)
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "20989d96-f34f-4be7-a0f9-4b92fb7f235a",
- "metadata": {},
- "outputs": [],
- "source": [
def get_power(points):
    """Return (frequencies, power) for a real signal on Z_p.

    Power is folded onto the non-negative frequencies: every bin strictly
    between DC and Nyquist is doubled so that sum(power) equals
    ||points||^2 (Parseval's identity). A warning is printed if the
    identity fails beyond rounding error.
    """
    p = len(points)
    num_coefficients = p // 2 + 1

    # FFT restricted to non-negative frequencies.
    spectrum = np.fft.fft(points)[:num_coefficients]
    power = np.abs(spectrum) ** 2 / p

    # Fold in the mirrored negative frequencies; for even p the Nyquist bin
    # has no distinct mirror and stays undoubled.
    if p % 2 == 0:
        power[1:num_coefficients - 1] *= 2
    else:
        power[1:] *= 2

    # Sanity check against Parseval's identity.
    total_power = np.sum(power)
    norm_squared = np.linalg.norm(points) ** 2
    if not np.isclose(total_power, norm_squared, rtol=1e-3):
        print(f"Warning: Total power {total_power:.3f} does not match norm squared {norm_squared:.3f}")

    return np.arange(num_coefficients), power
- "\n",
- "def interpolate(ax, points, color, continuous, alpha=1.0):\n",
- " p = len(points)\n",
- " if continuous:\n",
- " # Perform Fourier Transform\n",
- " ft = np.fft.fft(points)\n",
- " \n",
- " # Keep only non-negative frequencies (first half + Nyquist if p is even)\n",
- " num_coefficients = (p // 2) + 1\n",
- " ft = ft[:num_coefficients] # Truncate to keep non-negative frequencies\n",
- " \n",
- " # Create a dense set of x-values for smooth interpolation\n",
- " xs = np.linspace(0, p, 10 * p) # 10 times more points than the original for smoothness\n",
- " curr_val = np.zeros(xs.shape, dtype=complex)\n",
- " \n",
- " # Use only non-negative frequencies for interpolation\n",
- " for freq in range(num_coefficients):\n",
- " theta = np.angle(ft[freq])\n",
- " r = np.abs(ft[freq]) / p\n",
- " # Double amplitude except for DC (freq = 0) and Nyquist (freq = p / 2, when p is even)\n",
- " if freq > 0 and (freq < p / 2 or p % 2 != 0):\n",
- " r *= 2\n",
- " curr_val += r * np.exp(1j * ((2 * np.pi * freq * xs / p) + theta))\n",
- "\n",
- " # Plot the real part (since output is real-valued)\n",
- " ax.plot(xs, curr_val.real, color=color, alpha=alpha)\n",
- " else:\n",
- " ax.plot(np.arange(p), points, color=color, alpha=alpha) "
- ]
- },
- {
- "cell_type": "markdown",
- "id": "e99dae27-f8fe-403a-b70f-0bcaf818cbe7",
- "metadata": {},
- "source": [
- "## Gradient Descent Experiment"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "bcd15c5a-5745-41ba-b015-48e403160c7e",
- "metadata": {},
- "outputs": [],
- "source": [
- "seed = 0 # or any integer you like\n",
- "random.seed(seed)\n",
- "np.random.seed(seed)\n",
- "torch.manual_seed(seed)\n",
- "torch.cuda.manual_seed_all(seed) # if using GPU\n",
- "\n",
- "# TEST_MODE: Reduce p, hidden_size and epochs for faster automated testing\n",
- "import os\n",
- "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n",
- "\n",
- "# Data Generation using the new function\n",
- "p = 10 # Keep same value in TEST_MODE to avoid index errors # Modulus (reduced in test mode)\n",
- "\n",
- "# Get base vector\n",
- "# template = generate_template(p, 2, 1.0)\n",
- "# template = one_hot(p)\n",
- "template = generate_fixed_template(p)\n",
- "\n",
- "# Mean center template\n",
- "template -= np.mean(template)\n",
- "\n",
- "# Generate dataset using numpy\n",
- "X, Y = DihedralDataset(p, template)\n",
- "\n",
- "# Convert to PyTorch tensors\n",
- "X_tensor = torch.tensor(X, dtype=torch.float32).view(-1, 2 * p) # Flatten input (num_samples, 2*p)\n",
- "Y_tensor = torch.tensor(Y, dtype=torch.float32) # Targets (num_samples, p)\n",
- "\n",
- "# Create a TensorDataset and DataLoader\n",
- "dataset = TensorDataset(X_tensor, Y_tensor)\n",
- "dataloader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)\n",
- "# dataloader = DataLoader(dataset, batch_size=32, shuffle=False)\n",
- "\n",
- "# Initialize model\n",
- "hidden_size = 6 if TEST_MODE else 6 * 3 # Reduced in test mode\n",
- "model = TwoLayerNet(p=p, hidden_size=hidden_size, nonlinearity='square', init_scale=1e-2, output_scale=1e0)\n",
- "\n",
- "# Create loss function\n",
- "loss = nn.MSELoss()\n",
- "\n",
- "# Construct optimizer\n",
- "lr, mom = 0.01, 0.9\n",
- "optimizer = optim.SGD(model.parameters(), lr=lr, momentum=mom)\n",
- "# optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))\n",
- "\n",
- "# Train the model\n",
- "epochs = 2 if TEST_MODE else 10000\n",
- "loss_history, accuracy_history, param_history = train(model, dataloader, loss, optimizer, epochs=epochs, verbose_interval=max(1, epochs//100))"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "eae371c4-1405-4ac5-982c-0ebacb688ed7",
- "metadata": {},
- "source": [
- "## AGF Numerics"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "489e82e1-61c8-43e6-b260-fd96c815dec8",
- "metadata": {},
- "outputs": [],
- "source": [
- "class ModsumSubNetwork(nn.Module):\n",
- " \n",
- " def __init__(self, d_in, d_out, init_scale):\n",
- " super().__init__()\n",
- " assert d_in%2 == 0\n",
- " self.p = d_in // 2\n",
- " self.u = nn.Linear(self.p, 1, bias=False)\n",
- " self.v = nn.Linear(self.p, 1, bias=False)\n",
- " self.w = nn.Linear(1, d_out, bias=False)\n",
- " with torch.no_grad():\n",
- " self.w.weight.mul_(init_scale)\n",
- " self.u.weight.mul_(init_scale)\n",
- " self.v.weight.mul_(init_scale)\n",
- " self.active = False\n",
- " self.util_acc = 0\n",
- " self.c_a = 1/self.get_norm() - 1\n",
- " \n",
- " self.normalize()\n",
- " \n",
- " def get_norm(self):\n",
- " sqnorm = lambda x: torch.linalg.norm(x.weight)**2\n",
- " norm = torch.sqrt(sqnorm(self.w) + sqnorm(self.u) + sqnorm(self.v))\n",
- " return norm\n",
- " \n",
- " def reinitialize(self, u, v, w):\n",
- " with torch.no_grad():\n",
- " self.u.weight.copy_(u)\n",
- " self.v.weight.copy_(v)\n",
- " self.w.weight.copy_(w)\n",
- " self.c_a = 1/self.get_norm() - 1\n",
- " \n",
- " def forward(self, x):\n",
- " x1 = x[:, :self.p]\n",
- " x2 = x[:, self.p:]\n",
- " return self.w((self.u(x1) + self.v(x2))**2)\n",
- " \n",
- " def normalize(self):\n",
- " norm = self.get_norm()\n",
- " with torch.no_grad():\n",
- " self.w.weight.div_(norm)\n",
- " self.u.weight.div_(norm)\n",
- " self.v.weight.div_(norm)\n",
- " \n",
- " def utility_step(self, x, residual, learning_rate):\n",
- " f_i = self(x)\n",
- " util = torch.einsum('nd,nd->n', f_i, residual).mean()\n",
- " self.util_acc += 3 * learning_rate * util.item()\n",
- " norm_th = 1/(1 + self.c_a - self.util_acc)\n",
- " \n",
- " util.backward()\n",
- " with torch.no_grad():\n",
- " self.w.weight += norm_th * learning_rate * self.w.weight.grad\n",
- " self.u.weight += norm_th * learning_rate * self.u.weight.grad\n",
- " self.v.weight += norm_th * learning_rate * self.v.weight.grad\n",
- " self.w.weight.grad.zero_()\n",
- " self.u.weight.grad.zero_()\n",
- " self.v.weight.grad.zero_()\n",
- " self.normalize()\n",
- "\n",
- "\n",
class ModsumNetwork(nn.Module):
    """A width-`width` population of ModsumSubNetwork neurons for AGF.

    Only neurons flagged `active` contribute to the forward pass; dormant
    neurons are grown separately during utility maximization.
    """

    def __init__(self, d_in, d_out, init_scale, width=100):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.width = width
        self.neurons = nn.ModuleList(
            ModsumSubNetwork(d_in, d_out, init_scale) for _ in range(width)
        )
        self.set_mode("utilmax")

    def load_init(self, U, V, W):
        """Overwrite every neuron's weights from the i-th rows of U, V, W."""
        for i, neuron in enumerate(self.neurons):
            neuron.reinitialize(U[i], V[i], W[i][:, None])

    def dormant(self):
        """Neurons that have not yet activated."""
        return [neuron for neuron in self.neurons if not neuron.active]

    def active(self):
        """Neurons currently contributing to the output."""
        return [neuron for neuron in self.neurons if neuron.active]

    def set_mode(self, mode):
        """Switch between 'utilmax' (train dormant) and 'costmin' (train active).

        Gradients are enabled exactly for the population the mode trains.
        """
        if mode not in ["utilmax", "costmin"]:
            raise ValueError("mode must be utilmax or costmin")
        self.mode = mode
        for neuron in self.neurons:
            # XOR: utilmax trains dormant neurons, costmin trains active ones.
            grad_on = (mode == "utilmax") ^ neuron.active
            for param in neuron.parameters():
                param.requires_grad = grad_on

    def forward(self, x):
        contributing = [neuron for neuron in self.neurons if neuron.active]
        if not contributing:
            # Nothing has activated yet: the network outputs zero.
            return torch.zeros(x.shape[0], self.d_out)
        stacked = torch.stack([neuron(x) for neuron in contributing], dim=0)
        return torch.sum(stacked, dim=0)
- "\n",
- "\n",
- "def train_agf(X_train, Y_train, init_sz=1e-3, agf_steps=5, from_init=None, \n",
- " utilmax_lr=1, costmin_lr=1, costmin_maxiter=1e4, loss_thresh=1e-4):\n",
- " \n",
- " # Initialize\n",
- " d_in, d_out = X_train.shape[-1], Y_train.shape[-1]\n",
- " if from_init:\n",
- " U, V, W = from_init[\"U\"], from_init[\"V\"], from_init[\"W\"]\n",
- " assert d_in == U.shape[1]*2\n",
- " assert d_out == W.shape[1]\n",
- " width = U.shape[0]\n",
- " net = ModsumNetwork(d_in, d_out, init_sz, width=width)#.cuda()\n",
- " net.load_init(U, V, W)\n",
- " else:\n",
- " net = ModsumNetwork(d_in, d_out, init_sz, width=agf_steps)#.cuda()\n",
- " X_train.requires_grad = False\n",
- " \n",
- " def update_results(results, t):\n",
- " results[\"t\"].append(t)\n",
- " residual = (Y_train - net(X_train))\n",
- " residual = residual.detach()\n",
- " results[\"residuals\"].append(residual)\n",
- " loss = (residual**2).mean().item()\n",
- " results[\"losses\"].append(loss)\n",
- " results[\"models\"].append(net.state_dict())\n",
- " results[\"pred\"].append(net(X_train).detach().cpu().clone())\n",
- " \n",
- " results = {\n",
- " \"t\": [],\n",
- " \"residuals\": [],\n",
- " \"losses\": [],\n",
- " \"models\": [],\n",
- " \"pred\": [],\n",
- " }\n",
- "\n",
- " t = 0\n",
- " update_results(results, t)\n",
- " for _ in tqdm(range(agf_steps)):\n",
- " \n",
- " # Utility Maximization\n",
- " residual = (1/d_out) * 2*(Y_train - net(X_train))\n",
- " residual = residual.detach()\n",
- " iters = 0\n",
- " mode = \"utilmax\"\n",
- " while mode == \"utilmax\":\n",
- " for n in net.neurons:\n",
- " if n.active:\n",
- " continue\n",
- " n.utility_step(X_train, residual, utilmax_lr)\n",
- " if n.util_acc > n.c_a:\n",
- " n.active = True\n",
- " mode = \"costmin\"\n",
- " # break\n",
- " iters += 1\n",
- " net.set_mode(mode)\n",
- " t += iters\n",
- "\n",
- " # Cost Minimization\n",
- " optimizer = torch.optim.SGD(net.parameters(), lr=costmin_lr, momentum=0.9)\n",
- " for i in range(int(costmin_maxiter)):\n",
- " optimizer.zero_grad(set_to_none=False)\n",
- " residual = Y_train - net(X_train)\n",
- " loss = (residual ** 2).mean()\n",
- " loss.backward()\n",
- " optimizer.step()\n",
- " net.set_mode(\"utilmax\")\n",
- "\n",
- " \n",
- " print(f\"loss: {loss.item():.5f}\")\n",
- " update_results(results, t)\n",
- "\n",
- " # Check for Termination\n",
- " if not net.dormant() or loss.item() < loss_thresh:\n",
- " break\n",
- " \n",
- " return results"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5b3eddcc-8c3e-45dd-8f57-45585a021f5d",
- "metadata": {},
- "outputs": [],
- "source": [
- "costmin_lr = 0.01\n",
- "utilmax_lr = 0.1\n",
- "results = train_agf(X_tensor, Y_tensor, init_sz=model.init_scale, agf_steps=50, from_init=param_history[0],\n",
- " utilmax_lr=utilmax_lr, costmin_lr=costmin_lr,\n",
- " costmin_maxiter=1e4, loss_thresh=1e-4)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0f48aebc-a439-405a-a057-3f5c24cca91a",
- "metadata": {},
- "source": [
- "## Plot Loss"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "ff46febe-abb5-459a-bb06-a18a26afb967",
- "metadata": {},
- "outputs": [],
- "source": [
- "fig, ax = plt.subplots(1, 1, figsize=(6, 6))\n",
- "ax.plot(list(loss_history), lw=4)\n",
- "\n",
- "# for lossval in results[\"losses\"]:\n",
- "# ax.axhline(lossval, alpha=0.3, ls=\":\", color=\"xkcd:slate\", zorder=-4, lw=2)\n",
- "\n",
- "f = utilmax_lr / (lr/(1-mom))\n",
- "# for t in results[\"t\"]:\n",
- "# ax.axvline(f*t, alpha=0.3, ls=\":\", color=\"xkcd:slate\", zorder=-4, lw=2)\n",
- "\n",
- "# times = results[\"t\"] + [epochs]\n",
- "# AGF_losses = results[\"losses\"] + [results[\"losses\"][-1]]\n",
- "# ax.step(f*np.array(times), AGF_losses, where=\"post\", lw=2, ls='dashed', color=\"k\")\n",
- "\n",
- "# === Compute power spectrum of template ===\n",
- "freq, power = get_power(template)\n",
- "valid = power > 1e-20\n",
- "freq, power = freq[valid], power[valid]\n",
- "sorted_idx = np.argsort(-power)\n",
- "freq, power = freq[sorted_idx], power[sorted_idx]\n",
- "\n",
- "alpha_values = [np.sum(power[k:]) for k in range(len(power))]\n",
- "coef = 1 / p\n",
- "for k, alpha in enumerate(alpha_values):\n",
- " ax.axhline(y=coef * alpha, color='black', linestyle='--', linewidth=2, zorder=-2)\n",
- "\n",
- "ax.set_xscale(\"log\")\n",
- "ax.set_yscale(\"log\")\n",
- "ax.set_xlim(1e1, 1e4)\n",
- "ax.set_ylim(1e-1, 1e1)\n",
- "ax.set_xlabel('Epochs', fontsize=24)\n",
- "ax.set_ylabel('Train Loss', fontsize=24)\n",
- "\n",
- "style_axes(ax)\n",
- "plt.grid(False)\n",
- "plt.tight_layout()\n",
- "plt.savefig(\"loss-without-lines.pdf\", bbox_inches=\"tight\")\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "40b851e7-6256-43cd-b9f3-aca38db04917",
- "metadata": {},
- "source": [
- "## Power Spectrum of output"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "68b25ca9-6339-49dd-9d45-577a51798a25",
- "metadata": {},
- "outputs": [],
- "source": [
- "# === SETTINGS ===\n",
- "p = Y_tensor.shape[1]\n",
- "num_freqs = p // 2 + 1\n",
- "\n",
- "# Compute template power spectrum\n",
- "template_ft = np.fft.rfft(template)\n",
- "template_power = np.abs(template_ft)[:num_freqs]\n",
- "\n",
- "# === Compute power spectrum of template ===\n",
- "freq, power = get_power(template)\n",
- "valid = power > 1e-20\n",
- "freq, power = freq[valid], power[valid]\n",
- "sorted_idx = np.argsort(-power)\n",
- "freq, power = freq[sorted_idx], power[sorted_idx]\n",
- "\n",
- "# === Theory lines ===\n",
- "alpha_values = [np.sum(power[k:]) for k in range(len(power))]\n",
- "coef = 1 / p\n",
- "theta0 = np.sqrt(2) * model.init_scale\n",
- "uMax = [np.sqrt(2 * p / 27) * (p * power[k] / 2)**(3/2) / p**2 for k in range(len(power))]\n",
- "tau_values = [(1 / theta0 - 1) / (3 * uMax[k]) for k in range(len(uMax))]\n",
- "step_size = 2 * coef * lr / (1 - mom)\n",
- "\n",
- "\n",
- "# Color settings\n",
- "cmap = plt.colormaps.get_cmap('tab20').resampled(num_freqs)\n",
- "manual_colors = {\n",
- " 0: 'tab:blue',\n",
- " 1: 'tab:orange',\n",
- " 2: 'tab:red',\n",
- " 3: 'tab:green',\n",
- " 4: 'tab:brown',\n",
- " 5: 'tab:purple',\n",
- "}\n",
- "colors = [manual_colors.get(i, cmap(i)) for i in range(num_freqs)]\n",
- "\n",
- "# Compute output power over time (GD)\n",
- "num_points = 1000\n",
- "steps = np.unique(np.logspace(0, np.log10(len(param_history) - 1), num_points, dtype=int))\n",
- "powers_over_time = []\n",
- "\n",
- "for step in steps:\n",
- " model.load_state_dict(param_history[step])\n",
- " model.eval()\n",
- " with torch.no_grad():\n",
- " outputs = model(X_tensor)\n",
- " ft = np.fft.rfft(outputs.detach().cpu().numpy(), axis=1)\n",
- " avg_power = np.mean(np.abs(ft), axis=0)\n",
- " powers_over_time.append(avg_power)\n",
- "\n",
- "powers_over_time = np.array(powers_over_time) # shape: (steps, freqs)\n",
- "\n",
- "\n",
- "# # Compute output power over time (AGF)\n",
- "# f = utilmax_lr / (lr/(1-mom))\n",
- "# AGF_steps = results[\"t\"]\n",
- "# powers_over_time_AGF = []\n",
- "# for i, step in enumerate(AGF_steps):\n",
- "# outputs = results[\"pred\"][i]\n",
- "# ft = np.fft.rfft(outputs.detach().cpu().numpy(), axis=1)\n",
- "# avg_power = np.mean(np.abs(ft), axis=0)\n",
- "# powers_over_time_AGF.append(avg_power)\n",
- "# powers_over_time_AGF = np.array(powers_over_time_AGF) # shape: (steps, freqs)\n",
- "# AGF_steps = [f * t for t in AGF_steps]\n",
- "\n",
- "# AGF_steps.append(epochs)\n",
- "# powers_over_time_AGF = np.vstack([\n",
- "# powers_over_time_AGF,\n",
- "# powers_over_time_AGF[-1, :]\n",
- "# ])\n",
- "\n",
- "# === PLOTTING ===\n",
- "fig, ax = plt.subplots(figsize=(6, 7))\n",
- "\n",
- "for k in range(num_freqs):\n",
- " color = colors[k]\n",
- " label = fr\"$\\xi = {k}$\" if k in [1, 3, 5] else None\n",
- " ax.plot(steps, powers_over_time[:, k], color=color, lw=3, label=label)\n",
- " label_agf = 'AGF' if k == 10 else None\n",
- " #ax.step(AGF_steps, powers_over_time_AGF[:, k], color='k', lw=2, ls='dashed', where=\"post\", label=label_agf)\n",
- " ax.axhline(template_power[k], color=color, linestyle='dotted', linewidth=2, alpha=0.5, zorder=-10)\n",
- "\n",
- "for k, tau in enumerate(tau_values):\n",
- " color = colors[freq[k]]\n",
- " ax.axvline(x=tau / step_size, color=color, linestyle='dashed', linewidth=2, alpha=0.5)\n",
- "\n",
- " # Add arrow at intersection\n",
- " x = tau / step_size\n",
- " y = template_power[freq[k]]\n",
- " #draw an arrow from the lower bound to the right\n",
- " #use default color cycle\n",
- " # ax.arrow(1.04 * x, y + 0.5, 1.5 * x, 0, \n",
- " # head_width=0.2, head_length=x*0.2, length_includes_head=True,\n",
- " # fc=color, ec=color, lw=4)\n",
- "\n",
- "# # Add vertical lines if needed\n",
- "# for step in time_steps:\n",
- "# ax.axvline(x=step, color='gray', alpha=0.5, linestyle='solid', linewidth=2)\n",
- "\n",
- "# Labeling and formatting\n",
- "ax.set_xscale('log')\n",
- "ax.set_xlim(5e1, 2e6)\n",
- "ax.set_xticks([1000, 10000, 100000, epochs-1])\n",
- "ax.set_ylabel(\"Power\", fontsize=24)\n",
- "ax.set_xlabel(\"Epochs\", fontsize=24)\n",
- "ax.legend(fontsize=14, title=\"Frequency\", title_fontsize=16, loc='upper right', bbox_to_anchor=(1, 0.9), labelspacing=0.25)\n",
- "\n",
- "style_axes(ax)\n",
- "ax.set_xticks([1000, 10000, 100000, epochs-1])\n",
- "ax.grid(False)\n",
- "plt.tight_layout()\n",
- "plt.savefig(\"fourier_power_only.pdf\", bbox_inches=\"tight\")\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5ef2c971-d9f1-41e6-b8eb-4e467496ccfd",
- "metadata": {},
- "source": [
- "## Plot outputs"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e333d1ab-1501-434f-86d2-82c10bb58f11",
- "metadata": {},
- "outputs": [],
- "source": [
- "import matplotlib.pyplot as plt\n",
- "\n",
- "# Choose time steps to visualize\n",
- "# steps_to_show = [1000, 10000, 100000, epochs-1]\n",
- "steps_to_show = [100, 1000, epochs-1]\n",
- "num_samples = 1 # how many examples to plot per row\n",
- "p = Y_tensor.shape[1]\n",
- "x = np.arange(p)\n",
- "\n",
- "fig, axes = plt.subplots(len(steps_to_show), 1, figsize=(6, 6), sharex=True)\n",
- "\n",
- "for row, step in enumerate(steps_to_show):\n",
- " # Load weights at this step\n",
- " model.load_state_dict(param_history[step])\n",
- " model.eval()\n",
- "\n",
- " indices = np.random.choice(len(Y_tensor), size=num_samples, replace=False)\n",
- " with torch.no_grad():\n",
- " preds = model(X_tensor[indices]).detach().cpu().numpy()\n",
- " truths = Y_tensor[indices].detach().cpu().numpy()\n",
- "\n",
- " ax = axes[row]\n",
- " for i, idx in enumerate(indices):\n",
- " a = idx // p\n",
- " b = idx % p\n",
- " label_true = r\"$(a + b) \\cdot x$\"\n",
- " label_pred = r\"$f(a \\cdot x, b \\cdot x)$\"\n",
- "\n",
- " # Plot ground truth\n",
- " interpolate(ax, truths[i], color=f\"C{i}\", alpha=0.9, continuous=True)\n",
- " ax.scatter(x, truths[i], color=f\"C{i}\", s=30, alpha=0.9, label=label_true)\n",
- "\n",
- " # Plot prediction\n",
- " interpolate(ax, preds[i], color='k', alpha=1.0, continuous=True)\n",
- " ax.scatter(x, preds[i], color='k', s=30, alpha=0.7, label=label_pred)\n",
- "\n",
- " style_axes(ax, numyticks=3, labelsize=12)\n",
- " ax.grid(False)\n",
- " ax.set_ylabel(fr\"$t = 10^{{{int(np.log10(step))}}}$\", fontsize=20)\n",
- "\n",
- " # Only bottom row gets x-ticks\n",
- " if row < len(steps_to_show) - 1:\n",
- " ax.tick_params(labelbottom=False)\n",
- "\n",
- " # ax.legend(loc='best', fontsize=12, title=fr\"$a = {a}, b = {b}$\", handlelength=0, labelspacing=0.1, title_fontsize=14, frameon=False)\n",
- " ax.legend(\n",
- " loc='center left',\n",
- " bbox_to_anchor=(0.95, 0.5), # X slightly beyond the right edge, Y centered\n",
- " fontsize=8,\n",
- " title=fr\"$a = {a}, b = {b}$\",\n",
- " title_fontsize=10,\n",
- " handlelength=0,\n",
- " labelspacing=0.1,\n",
- " frameon=False\n",
- " )\n",
- "\n",
- "# axes[-1].set_xlabel(\"Output Index\", fontsize=20)\n",
- "plt.tight_layout()\n",
- "plt.savefig(\"predictions.pdf\", bbox_inches='tight')"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "b267424b-a0e5-47e3-9e01-1dc41e05e026",
- "metadata": {},
- "source": [
- "## Plot Weights"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9de707bf-838e-4384-8150-3d8fe4586fc3",
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "import matplotlib.cm as cm\n",
- "import matplotlib.gridspec as gridspec\n",
- "\n",
- "# Steps and corresponding highlighted frequencies\n",
- "\n",
- "#steps = [1000, 10000, 100000, epochs-1]\n",
- "steps = [100, 1000, epochs-1]\n",
- "highlight_freqs_list = [[], [1], [3], [5]]\n",
- "\n",
- "num_rows, num_cols = len(steps), 3\n",
- "\n",
- "# Use gridspec to control layout\n",
- "fig = plt.figure(figsize=(24, 6), constrained_layout=True)\n",
- "gs = gridspec.GridSpec(num_rows, num_cols, width_ratios=[1.1, 1.1, 2.0], wspace=0.1, hspace=0.1)\n",
- "axes = np.empty((num_rows, num_cols), dtype=object)\n",
- "\n",
- "# Create axes\n",
- "for row in range(num_rows):\n",
- " for col in range(num_cols):\n",
- " if col == 2:\n",
- " ax = fig.add_subplot(gs[row, col], projection='polar')\n",
- " else:\n",
- " ax = fig.add_subplot(gs[row, col]) # \u2b05 no sharex anymore\n",
- " axes[row, col] = ax\n",
- "\n",
- "num_freqs = None\n",
- "for row, index in enumerate(steps):\n",
- " highlight_freqs = highlight_freqs_list[row]\n",
- " params = param_history[index]\n",
- " W = params['W'].numpy()\n",
- " h, p = W.shape\n",
- "\n",
- " if num_freqs is None:\n",
- " num_freqs = p // 2 + 1\n",
- " cmap = plt.colormaps.get_cmap('tab20').resampled(num_freqs)\n",
- " colors = [cmap(i) for i in range(num_freqs)]\n",
- " manual_colors = {\n",
- " 0: 'tab:blue',\n",
- " 1: 'tab:orange',\n",
- " 2: 'tab:red',\n",
- " 3: 'tab:green',\n",
- " 4: 'tab:brown',\n",
- " 5: 'tab:purple',\n",
- " }\n",
- " freq_colors = [manual_colors.get(i, cmap(i)) for i in range(num_freqs)]\n",
- "\n",
- "\n",
- " # === Column 1: Weights ===\n",
- " ax = axes[row, 0]\n",
- " for i in range(h):\n",
- " w = W[i, :]\n",
- " ft = np.fft.rfft(w)\n",
- " power = np.abs(ft)**2\n",
- " dom_idx = np.argmax(power)\n",
- " color = freq_colors[dom_idx]\n",
- " alpha = 0.9 if not highlight_freqs or dom_idx in highlight_freqs else 0.1\n",
- " x = np.linspace(0, p - 1, 500)\n",
- " interpolate(ax, w, color=color, continuous=True, alpha=alpha)\n",
- " ax.scatter(np.arange(p), w, color=color, s=10, alpha=alpha)\n",
- " if row == 0: ax.set_title(\"Weights\", fontsize=24)\n",
- " ax.set_ylabel(fr\"$t = 10^{{{int(np.log10(index))}}}$\", fontsize=20)\n",
- " style_axes(ax, numyticks=3, numxticks=5, labelsize=12)\n",
- " ax.grid(False)\n",
- " if row < num_rows - 1:\n",
- " ax.tick_params(labelbottom=False)\n",
- "\n",
- " # === Column 2: Frequency Spectrum ===\n",
- " ax = axes[row, 1]\n",
- " for i in range(h):\n",
- " w = W[i, :]\n",
- " ft = np.fft.rfft(w)\n",
- " power = np.abs(ft)**2\n",
- " for k in range(len(power)):\n",
- " color = freq_colors[k]\n",
- " ax.vlines(k, 0, power[k], linewidth=4, color=color, alpha=0.4)\n",
- " ax.scatter(k, power[k], color=color, s=50, alpha=0.7)\n",
- " # ax.axhline(0, color='gray', linewidth=1, linestyle='--', alpha=0.4)\n",
- " ax.set_xlim(-0.5, len(power) - 0.5)\n",
- " ax.set_xticks(np.arange(len(power)))\n",
- " if row == 0: ax.set_title(\"Frequency\", fontsize=24)\n",
- " style_axes(ax, numyticks=3, numxticks=11, labelsize=12)\n",
- " ax.grid(False)\n",
- " if row < num_rows - 1:\n",
- " ax.tick_params(labelbottom=False)\n",
- "\n",
- " # === Column 3: Phase Polar Plot ===\n",
- " ax = axes[row, 2]\n",
- " for i in range(h):\n",
- " w = W[i, :]\n",
- " ft = np.fft.rfft(w)\n",
- " power = np.abs(ft)**2\n",
- " dom_idx = np.argmax(power)\n",
- " phase = np.angle(ft[dom_idx])\n",
- " norm = np.linalg.norm(w)\n",
- " color = freq_colors[dom_idx]\n",
- " alpha = 0.9 if not highlight_freqs or dom_idx in highlight_freqs else 0.1\n",
- " ax.plot([phase, phase], [0, norm], color=color, linewidth=2, alpha=alpha)\n",
- " ax.scatter(phase, norm, color=color, s=40, alpha=alpha)\n",
- " angles = np.arange(0, 360, 45)\n",
- " # ax.set_thetagrids(angles, [f\"{a}\u00b0\" if a in [45,135,225,315] else \"\" for a in angles])\n",
- " ax.set_thetagrids(angles, [\"\" for a in angles])\n",
- " ax.set_yticklabels([])\n",
- " ax.spines['polar'].set_linewidth(2)\n",
- " if row == 0: ax.set_title(\"Phase\", fontsize=24)\n",
- "\n",
- "# Shift polar plots left to reduce whitespace\n",
- "for row in range(num_rows):\n",
- " ax = axes[row, 2]\n",
- " pos = ax.get_position()\n",
- " ax.set_position([pos.x0 - 0.155, pos.y0, pos.width, pos.height])\n",
- "\n",
- "plt.savefig(\"W-weights.pdf\", bbox_inches='tight')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9f5f56f9-7055-4056-9a18-7d91b3be50f8",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "group-agf",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.12"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
\ No newline at end of file
diff --git a/notebooks/dn.ipynb b/notebooks/dn.ipynb
new file mode 100644
index 0000000..85d2fd5
--- /dev/null
+++ b/notebooks/dn.ipynb
@@ -0,0 +1,307 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Binary Group Composition on $D_n$ (Dihedral Group)\n",
+ "\n",
+ "**Group:** Dihedral group $D_n$ of order $2n$ (rotations and reflections of a regular $n$-gon). \n",
+ "**Task:** Given encodings of two group elements $g_1, g_2 \\in D_n$, predict the encoding of their product $g_1 \\cdot g_2$. \n",
+ "**Sequence length:** $k = 2$ (binary composition). \n",
+ "**Architecture:** `TwoLayerNet` with square nonlinearity. \n",
+ "**Key result:** The network learns one irreducible representation at a time, producing a staircase in the training loss."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import random\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from escnn.group import DihedralGroup\n",
+ "from torch.utils.data import DataLoader, TensorDataset\n",
+ "\n",
+ "import src.dataset as dataset\n",
+ "import src.model as model\n",
+ "import src.optimizer as optimizer\n",
+ "import src.power as power\n",
+ "import src.template as template\n",
+ "import src.train as train_mod\n",
+ "import src.viz as viz"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Configuration"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n",
+ "\n",
+ "seed = 0\n",
+ "random.seed(seed)\n",
+ "np.random.seed(seed)\n",
+ "torch.manual_seed(seed)\n",
+ "\n",
+ "n = 5 # D_5 has order 2*5 = 10\n",
+ "group = DihedralGroup(n)\n",
+ "group_size = group.order()\n",
+ "\n",
+ "hidden_size = 20 if TEST_MODE else 180\n",
+ "epochs = 2 if TEST_MODE else 2000\n",
+ "lr = 0.01\n",
+ "init_scale = 1e-3\n",
+ "\n",
+ "FIGURES_DIR = \"figures\"\n",
+ "os.makedirs(FIGURES_DIR, exist_ok=True)\n",
+ "\n",
+ "print(f\"Group: D_{n}, order {group_size}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Template and Dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Build a template with known Fourier structure on D_n\n",
+ "# D_5 has 4 irreps: two 1D and two 2D\n",
+ "# fourier_coef_diag_values: one value per irrep\n",
+ "fourier_coef_diag_values = [0.0, 5.0, 30.0, 300.0]\n",
+ "tpl = template.fixed_group(group, fourier_coef_diag_values)\n",
+ "\n",
+ "# Build exhaustive dataset: all group_size^2 pairs\n",
+ "X, Y = dataset.group_dataset(group, tpl)\n",
+ "X_tensor, Y_tensor, device = dataset.move_dataset_to_device_and_flatten(X, Y)\n",
+ "\n",
+ "ds = TensorDataset(X_tensor, Y_tensor)\n",
+ "dataloader = DataLoader(ds, batch_size=len(ds), shuffle=False)\n",
+ "\n",
+ "print(f\"Dataset: {len(ds)} samples (all {group_size}x{group_size} pairs)\")\n",
+ "print(f\"X shape: {X_tensor.shape}, Y shape: {Y_tensor.shape}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Visualize template and its group power spectrum\n",
+ "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
+ "\n",
+ "ax1.bar(range(group_size), tpl, color=\"black\")\n",
+ "ax1.set_xlabel(\"Group element\")\n",
+ "ax1.set_ylabel(\"Template value\")\n",
+ "ax1.set_title(f\"Template $t$ on $D_{{{n}}}$\")\n",
+ "\n",
+ "gp = power.GroupPower(tpl, group)\n",
+ "pwr = gp.group_power_spectrum()\n",
+ "ax2.bar(range(len(pwr)), pwr, color=\"steelblue\")\n",
+ "ax2.set_xlabel(\"Irrep index\")\n",
+ "ax2.set_ylabel(\"Power\")\n",
+ "ax2.set_title(\"Power spectrum (by irrep)\")\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.savefig(f\"{FIGURES_DIR}/dihedral_template.pdf\", bbox_inches=\"tight\")\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Model and Optimizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "net = model.TwoLayerNet(\n",
+ " group_size=group_size,\n",
+ " hidden_size=hidden_size,\n",
+ " nonlinearity=\"square\",\n",
+ " init_scale=init_scale,\n",
+ ")\n",
+ "net = net.to(device)\n",
+ "\n",
+ "criterion = nn.MSELoss()\n",
+ "opt = optimizer.PerNeuronScaledSGD(net, lr=lr, degree=3)\n",
+ "\n",
+ "print(f\"Model: TwoLayerNet(group_size={group_size}, hidden={hidden_size}, init_scale={init_scale})\")\n",
+ "print(f\"Optimizer: PerNeuronScaledSGD(lr={lr}, degree=3)\")\n",
+ "print(f\"Training for {epochs} epochs\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "loss_history, val_loss_history, param_history, param_save_epochs, final_epoch = train_mod.train(\n",
+ " net,\n",
+ " dataloader,\n",
+ " criterion,\n",
+ " opt,\n",
+ " epochs=epochs,\n",
+ " verbose_interval=max(1, epochs // 10),\n",
+ " save_param_interval=1,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Training Loss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Compute theoretical loss plateau levels\n",
+ "theory_levels = gp.loss_plateau_predictions()\n",
+ "\n",
+ "fig, ax = plt.subplots(figsize=(8, 6))\n",
+ "ax.plot(loss_history, lw=4)\n",
+ "\n",
+ "for level in theory_levels:\n",
+ " ax.axhline(y=level, color=\"black\", linestyle=\"--\", linewidth=2, zorder=-2)\n",
+ "\n",
+ "ax.set_xscale(\"log\")\n",
+ "ax.set_yscale(\"log\")\n",
+ "ax.set_xlabel(\"Epochs\", fontsize=18)\n",
+ "ax.set_ylabel(\"Train Loss\", fontsize=18)\n",
+ "ax.set_title(f\"Training loss on $D_{{{n}}}$\", fontsize=20)\n",
+ "viz.style_axes(ax)\n",
+ "ax.grid(False)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.savefig(f\"{FIGURES_DIR}/dihedral_loss.pdf\", bbox_inches=\"tight\")\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Power Spectrum Over Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Use model_power_over_time from src/power.py\n",
+ "powers_over_time, power_steps = power.model_power_over_time(\n",
+ " group_name=\"dihedral\",\n",
+ " model=net,\n",
+ " param_history=param_history,\n",
+ " model_inputs=X_tensor,\n",
+ " group=group,\n",
+ ")\n",
+ "\n",
+ "# Reference: template power per irrep\n",
+ "template_pwr = gp.group_power_spectrum()\n",
+ "\n",
+ "# Plot\n",
+ "colors = [\"tab:blue\", \"tab:orange\", \"tab:red\", \"tab:green\", \"tab:brown\", \"tab:purple\"]\n",
+ "fig, ax = plt.subplots(figsize=(8, 6))\n",
+ "\n",
+ "n_irreps = powers_over_time.shape[1]\n",
+ "for k in range(n_irreps):\n",
+ " color = colors[k] if k < len(colors) else f\"C{k}\"\n",
+ " ax.plot(power_steps, powers_over_time[:, k], color=color, lw=4, label=rf\"$\\rho_{{{k}}}$\")\n",
+ " ax.axhline(template_pwr[k], color=color, linestyle=\"dotted\", linewidth=2, alpha=0.5, zorder=-10)\n",
+ "\n",
+ "ax.set_xscale(\"log\")\n",
+ "ax.set_ylabel(\"Power\", fontsize=18)\n",
+ "ax.set_xlabel(\"Epochs\", fontsize=18)\n",
+ "ax.set_title(f\"Power spectrum over training on $D_{{{n}}}$\", fontsize=20)\n",
+ "ax.legend(fontsize=12, title=\"Irrep\", title_fontsize=14, loc=\"upper left\", labelspacing=0.25)\n",
+ "viz.style_axes(ax)\n",
+ "ax.grid(False)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.savefig(f\"{FIGURES_DIR}/dihedral_power_spectrum.pdf\", bbox_inches=\"tight\")\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Irreducible Representations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig = viz.plot_irreps(group, show=False)\n",
+ "plt.savefig(f\"{FIGURES_DIR}/dihedral_irreps.pdf\", bbox_inches=\"tight\")\n",
+ "plt.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "group-agf",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.12.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/notebooks/fourier_power_only.svg b/notebooks/fourier_power_only.svg
deleted file mode 100644
index b39f1e3..0000000
--- a/notebooks/fourier_power_only.svg
+++ /dev/null
@@ -1,1054 +0,0 @@
-
-
-
diff --git a/notebooks/loss-without-lines.svg b/notebooks/loss-without-lines.svg
deleted file mode 100644
index e379d4c..0000000
--- a/notebooks/loss-without-lines.svg
+++ /dev/null
@@ -1,778 +0,0 @@
-
-
-
diff --git a/notebooks/modular_arithmetic.ipynb b/notebooks/modular_arithmetic.ipynb
deleted file mode 100644
index 8f4f355..0000000
--- a/notebooks/modular_arithmetic.ipynb
+++ /dev/null
@@ -1,1063 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "51d11caf-0971-4324-b63b-819b714a9c3c",
- "metadata": {},
- "source": [
- "# Modular Addition"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "80f249f1-6985-4c73-86cd-04e1adac3e8d",
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import random\n",
- "import torch\n",
- "import torch.nn as nn\n",
- "from tqdm import tqdm\n",
- "import torch.optim as optim\n",
- "from torch.utils.data import DataLoader, TensorDataset\n",
- "import seaborn as sns\n",
- "import matplotlib.pyplot as plt\n",
- "import matplotlib.cm as cm\n",
- "from matplotlib.animation import FuncAnimation\n",
- "from matplotlib.ticker import FormatStrFormatter\n",
- "from matplotlib.ticker import FuncFormatter\n",
- "from matplotlib.ticker import MaxNLocator"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "9fd05577-db56-4d0a-bb93-1d0b48cecaf6",
- "metadata": {},
- "source": [
- "## Dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "f19bd1ad-9e8f-4720-b317-afe13fafae88",
- "metadata": {},
- "outputs": [],
- "source": [
- "def one_hot(p):\n",
- " \"\"\"One-hot encode an integer value in R^p.\"\"\"\n",
- " vec = np.zeros(p)\n",
- " vec[0] = 1\n",
- " return vec\n",
- "\n",
- "def generate_template(p, magnitude, exponent):\n",
- " weight = magnitude * np.power(np.arange(1, p), -exponent) # Power-law singular values\n",
- " template = np.ones(p) # Base term (DC component)\n",
- " for freq in range(1, p):\n",
- " template += weight[freq-1] * np.cos(np.arange(p) * freq / p * 2 * np.pi)\n",
- " return template / p\n",
- "\n",
- "def generate_fixed_template(p):\n",
- " # Generate template array from Fourier spectrum\n",
- " spectrum = np.zeros(p, dtype=complex)\n",
- " \n",
- " # Set only three frequencies with specific amplitudes\n",
- " spectrum[1] = 10 # Positive frequency\n",
- " spectrum[-1] = 10 # Negative frequency (conjugate)\n",
- " spectrum[3] = 5 # Second frequency\n",
- " spectrum[-3] = 5 # Its conjugate\n",
- " spectrum[5] = 2.5 # Third frequency \n",
- " spectrum[-5] = 2.5 # Its conjugate\n",
- " \n",
- " # Generate signal from spectrum\n",
- " template = np.fft.ifft(spectrum).real\n",
- "\n",
- " return template\n",
- "\n",
- "def ModularAdditionDataset(p, template):\n",
- " # Initialize data arrays\n",
- " X = np.zeros((p * p, 2, p)) # Shape: (p^2, 2, p)\n",
- " Y = np.zeros((p * p, p)) # Shape: (p^2, p)\n",
- " \n",
- " # Generate the dataset\n",
- " idx = 0\n",
- " for a in range(p):\n",
- " for b in range(p):\n",
- " q = (a + b) % p # a + b mod p\n",
- " X[idx, 0, :] = np.roll(template, a)\n",
- " X[idx, 1, :] = np.roll(template, b)\n",
- " Y[idx, :] = np.roll(template, q)\n",
- " idx += 1\n",
- " \n",
- " return X, Y"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "7a0ecbbd-ceaf-4bef-af4a-13a22fa70063",
- "metadata": {},
- "source": [
- "## Architecture"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "2cf22b7d-49e7-445b-8742-2e75cd1fa55a",
- "metadata": {},
- "outputs": [],
- "source": [
- "class TwoLayerNet(nn.Module):\n",
- " def __init__(self, p, hidden_size, nonlinearity='square', init_scale=1.0, output_scale=1.0):\n",
- " super(TwoLayerNet, self).__init__()\n",
- " \n",
- " # Store dimensions\n",
- " self.p = p\n",
- " self.hidden_size = hidden_size\n",
- " self.nonlinearity = nonlinearity\n",
- " self.init_scale = init_scale\n",
- " self.output_scale = output_scale\n",
- " \n",
- " # Initialize parameters \n",
- " self.U = nn.Parameter(self.init_scale * torch.randn(hidden_size, p) / np.sqrt(2 * p)) # First p elements\n",
- " self.V = nn.Parameter(self.init_scale * torch.randn(hidden_size, p) / np.sqrt(2 * p)) # Second p elements\n",
- " self.W = nn.Parameter(self.init_scale * torch.randn(hidden_size, p) / np.sqrt(p)) # Second layer weights\n",
- " print(f\"Initialized U with shape {self.U.shape}\")\n",
- " print(f\"Initialized V with shape {self.V.shape}\")\n",
- " print(f\"Initialized W with shape {self.W.shape}\")\n",
- "\n",
- " def forward(self, x):\n",
- " print(f\"Input x shape: {x.shape}\")\n",
- " # First layer (linear and combined)\n",
- " x1 = x[:, :self.p] @ self.U.T\n",
- " print(f\"x1 (x @ U.T) shape: {x1.shape}\")\n",
- " x2 = x[:, self.p:] @ self.V.T\n",
- " print(f\"x2 (x @ V.T) shape: {x2.shape}\")\n",
- " x_combined = x1 + x2\n",
- " print(f\"x_combined (x1 + x2) shape: {x_combined.shape}\")\n",
- "\n",
- " # Apply nonlinearity activation\n",
- " if self.nonlinearity == 'relu':\n",
- " x_combined = torch.relu(x_combined)\n",
- " print(\"Applied ReLU nonlinearity\")\n",
- " elif self.nonlinearity == 'square':\n",
- " x_combined = x_combined**2\n",
- " print(\"Applied square nonlinearity\")\n",
- " elif self.nonlinearity == 'linear':\n",
- " x_combined = x_combined\n",
- " print(\"Applied linear (identity) nonlinearity\")\n",
- " elif self.nonlinearity == 'tanh':\n",
- " x_combined = torch.tanh(x_combined)\n",
- " print(\"Applied tanh nonlinearity\")\n",
- " elif self.nonlinearity == 'gelu':\n",
- " gelu = torch.nn.GELU()\n",
- " x_combined = gelu(x_combined)\n",
- " print(\"Applied GELU nonlinearity\")\n",
- " else:\n",
- " raise ValueError(f\"Invalid nonlinearity '{self.nonlinearity}' provided.\")\n",
- "\n",
- " # Second layer (linear)\n",
- " x_out = x_combined @ self.W\n",
- " print(f\"x_out (x_combined @ W) shape: {x_out.shape}\")\n",
- "\n",
- " # Feature learning scaling\n",
- " x_out *= self.output_scale\n",
- " print(f\"x_out after scaling with output_scale={self.output_scale}: shape {x_out.shape}\")\n",
- " \n",
- " return x_out"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "f7e7336b-5c6e-48af-a357-2b2c877f6168",
- "metadata": {},
- "source": [
- "## Optimization"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "1035f81c-e877-4655-8640-4e4c3d323af8",
- "metadata": {},
- "outputs": [],
- "source": [
- "def test_accuracy(model, dataloader):\n",
- " correct = 0\n",
- " total = 0\n",
- " print(\"Starting test_accuracy evaluation...\")\n",
- " \n",
- " with torch.no_grad(): # Disable gradient calculation for evaluation\n",
- " for i, (inputs, labels) in enumerate(dataloader):\n",
- " inputs = inputs.view(inputs.shape[0], -1) # Flatten input for FC layers\n",
- " print(f\"Batch {i+1}: inputs reshaped\")\n",
- " outputs = model(inputs)\n",
- " print(f\"Batch {i+1}: model forward pass done\")\n",
- " _, predicted = torch.max(outputs, 1) # Get the index of the largest value (class)\n",
- " _, true_labels = torch.max(labels, 1) # Get the true class from the one-hot encoding\n",
- " correct += (predicted == true_labels).sum().item()\n",
- " total += labels.size(0)\n",
- " print(f\"Batch {i+1}: accuracy updated (correct={correct}, total={total})\")\n",
- " \n",
- " accuracy = 100 * correct / total\n",
- " print(f\"Final test accuracy: {accuracy:.2f}%\")\n",
- " return accuracy\n",
- "\n",
- "def train(model, dataloader, criterion, optimizer, epochs=100, verbose_interval=10):\n",
- " print(\"Starting training loop...\")\n",
- " model.train() # Set the model to training mode\n",
- " print(\"Model set to train mode.\")\n",
- " loss_history = [] # List to store loss values\n",
- " accuracy_history = []\n",
- " param_history = []\n",
- "\n",
- " for epoch in range(epochs):\n",
- " print(f\"Epoch {epoch+1} started.\")\n",
- " running_loss = 0.0\n",
- " for batch_idx, (inputs, labels) in enumerate(dataloader):\n",
- " inputs = inputs.view(inputs.shape[0], -1) # Flatten input for FC layers\n",
- " print(f\" Batch {batch_idx+1}: inputs reshaped\")\n",
- "\n",
- " optimizer.zero_grad() # Zero gradients\n",
- " print(f\" Batch {batch_idx+1}: optimizer gradients zeroed\")\n",
- " outputs = model(inputs) # Forward pass\n",
- " print(f\" Batch {batch_idx+1}: model forward pass done\")\n",
- " loss = criterion(outputs, labels) # Compute loss\n",
- " print(f\" Batch {batch_idx+1}: loss computed ({loss.item():.4f})\")\n",
- " loss.backward() # Backpropagation\n",
- " print(f\" Batch {batch_idx+1}: backward pass done\")\n",
- " optimizer.step() # Update weights\n",
- " print(f\" Batch {batch_idx+1}: optimizer step done\")\n",
- "\n",
- " running_loss += loss.item()\n",
- " print(f\" Batch {batch_idx+1}: running_loss updated ({running_loss:.4f})\")\n",
- "\n",
- " # Append the average loss for the epoch to loss_history\n",
- " avg_loss = running_loss / len(dataloader)\n",
- " loss_history.append(avg_loss)\n",
- " print(f\"Epoch {epoch+1}: avg_loss appended ({avg_loss:.4f})\")\n",
- "\n",
- " # Append the accuracy\n",
- " model.eval()\n",
- " print(f\"Epoch {epoch+1}: model set to eval mode for accuracy check\")\n",
- " accuracy = test_accuracy(model, dataloader)\n",
- " accuracy_history.append(accuracy)\n",
- " print(f\"Epoch {epoch+1}: accuracy appended ({accuracy:.2f}%)\")\n",
- " model.train()\n",
- " print(f\"Epoch {epoch+1}: model set back to train mode\")\n",
- "\n",
- " # Save current model parameters\n",
- " current_params = {\n",
- " \"U\": model.U.detach().cpu().clone(),\n",
- " \"V\": model.V.detach().cpu().clone(),\n",
- " \"W\": model.W.detach().cpu().clone()\n",
- " }\n",
- " param_history.append(current_params)\n",
- " print(f\"Epoch {epoch+1}: model parameters saved\")\n",
- "\n",
- " # Print verbose information every `verbose_interval` epochs\n",
- " if (epoch + 1) % verbose_interval == 0:\n",
- " print(f\"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%\")\n",
- "\n",
- " print(\"Training loop finished.\")\n",
- " return loss_history, accuracy_history, param_history # Return loss history for plotting"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0e86c4f6-83a6-4465-abf0-7d104432cc9c",
- "metadata": {},
- "source": [
- "## Plotting functions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "014e2d10-9550-4fd4-adb7-168a27fda1b3",
- "metadata": {},
- "outputs": [],
- "source": [
- "def style_axes(ax, numyticks=5, numxticks=5, labelsize=24):\n",
- " # Y-axis ticks\n",
- " ax.tick_params(axis=\"y\", which=\"both\", bottom=True, top=False,\n",
- " labelbottom=True, left=True, right=False,\n",
- " labelleft=True, direction='out', length=7, width=1.5, pad=8, labelsize=labelsize)\n",
- " ax.yaxis.set_major_locator(MaxNLocator(nbins=numyticks))\n",
- " \n",
- " # X-axis ticks\n",
- " ax.tick_params(axis=\"x\", which=\"both\", bottom=True, top=False,\n",
- " labelbottom=True, left=True, right=False,\n",
- " labelleft=True, direction='out', length=7, width=1.5, pad=8, labelsize=labelsize)\n",
- " ax.xaxis.set_major_locator(MaxNLocator(nbins=numxticks))\n",
- "\n",
- " # Scientific notation formatting\n",
- " if ax.get_yscale() == 'linear':\n",
- " ax.ticklabel_format(style='sci', axis='y', scilimits=(-2, 2))\n",
- " if ax.get_xscale() == 'linear':\n",
- " ax.ticklabel_format(style='sci', axis='x', scilimits=(-2, 2))\n",
- "\n",
- " ax.xaxis.offsetText.set_fontsize(20)\n",
- " ax.grid()\n",
- "\n",
- " # Customize spines\n",
- " for spine in [\"top\", \"right\"]:\n",
- " ax.spines[spine].set_visible(False)\n",
- " for spine in [\"left\", \"bottom\"]:\n",
- " ax.spines[spine].set_linewidth(3)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "20989d96-f34f-4be7-a0f9-4b92fb7f235a",
- "metadata": {},
- "outputs": [],
- "source": [
- "def get_power(points):\n",
- " p = len(points)\n",
- " num_coefficients = (p // 2) + 1\n",
- " \n",
- " # Perform FFT and calculate power spectrum\n",
- " ft = np.fft.fft(points) # Could consider using np.fft.rfft which is designed for real valued input.\n",
- " power = np.abs(ft[:num_coefficients])**2 / p\n",
- " \n",
- " # Double power for frequencies strictly between 0 and Nyquist (Nyquist is not doubled if p is even)\n",
- " if p % 2 == 0: # p is even, Nyquist frequency at index num_coefficients - 1\n",
- " power[1:num_coefficients - 1] *= 2\n",
- " else: # p is odd, no Nyquist frequency\n",
- " power[1:] *= 2\n",
- "\n",
- " # Confirm the power sum approximates the squared norm of points\n",
- " total_power = np.sum(power)\n",
- " norm_squared = np.linalg.norm(points)**2\n",
- " if not np.isclose(total_power, norm_squared, rtol=1e-3):\n",
- " print(f\"Warning: Total power {total_power:.3f} does not match norm squared {norm_squared:.3f}\")\n",
- "\n",
- " return np.arange(num_coefficients), power\n",
- "\n",
- "def interpolate(ax, points, color, continuous, alpha=1.0):\n",
- " p = len(points)\n",
- " if continuous:\n",
- " # Perform Fourier Transform\n",
- " ft = np.fft.fft(points)\n",
- " \n",
- " # Keep only non-negative frequencies (first half + Nyquist if p is even)\n",
- " num_coefficients = (p // 2) + 1\n",
- " ft = ft[:num_coefficients] # Truncate to keep non-negative frequencies\n",
- " \n",
- " # Create a dense set of x-values for smooth interpolation\n",
- " xs = np.linspace(0, p, 10 * p) # 10 times more points than the original for smoothness\n",
- " curr_val = np.zeros(xs.shape, dtype=complex)\n",
- " \n",
- " # Use only non-negative frequencies for interpolation\n",
- " for freq in range(num_coefficients):\n",
- " theta = np.angle(ft[freq])\n",
- " r = np.abs(ft[freq]) / p\n",
- " # Double amplitude except for DC (freq = 0) and Nyquist (freq = p / 2, when p is even)\n",
- " if freq > 0 and (freq < p / 2 or p % 2 != 0):\n",
- " r *= 2\n",
- " curr_val += r * np.exp(1j * ((2 * np.pi * freq * xs / p) + theta))\n",
- "\n",
- " # Plot the real part (since output is real-valued)\n",
- " ax.plot(xs, curr_val.real, color=color, alpha=alpha)\n",
- " else:\n",
- " ax.plot(np.arange(p), points, color=color, alpha=alpha) "
- ]
- },
- {
- "cell_type": "markdown",
- "id": "e99dae27-f8fe-403a-b70f-0bcaf818cbe7",
- "metadata": {},
- "source": [
- "## Gradient Descent Experiment"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "bcd15c5a-5745-41ba-b015-48e403160c7e",
- "metadata": {},
- "outputs": [],
- "source": [
- "seed = 0 # or any integer you like\n",
- "random.seed(seed)\n",
- "np.random.seed(seed)\n",
- "torch.manual_seed(seed)\n",
- "torch.cuda.manual_seed_all(seed) # if using GPU\n",
- "\n",
- "# Data Generation using the new function\n",
- "# TEST_MODE: Reduce p and hidden_size for faster automated testing\n",
- "import os\n",
- "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n",
- "p = 20 # Keep same value in TEST_MODE to avoid index errors # Modulus (reduced in test mode)\n",
- "\n",
- "# Get base vector\n",
- "# template = generate_template(p, 2, 1.0)\n",
- "# template = one_hot(p)\n",
- "template = generate_fixed_template(p)\n",
- "\n",
- "# Mean center template\n",
- "template -= np.mean(template)\n",
- "\n",
- "# Generate dataset using numpy\n",
- "X, Y = ModularAdditionDataset(p, template)\n",
- "\n",
- "# Convert to PyTorch tensors\n",
- "X_tensor = torch.tensor(X, dtype=torch.float32).view(-1, 2 * p) # Flatten input (num_samples, 2*p)\n",
- "Y_tensor = torch.tensor(Y, dtype=torch.float32) # Targets (num_samples, p)\n",
- "\n",
- "# Create a TensorDataset and DataLoader\n",
- "dataset = TensorDataset(X_tensor, Y_tensor)\n",
- "dataloader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)\n",
- "# dataloader = DataLoader(dataset, batch_size=32, shuffle=False)\n",
- "\n",
- "# Initialize model\n",
- "hidden_size = 6 if TEST_MODE else 6 * 3 # Reduced in test mode\n",
- "model = TwoLayerNet(p=p, hidden_size=hidden_size, nonlinearity='square', init_scale=1e-2, output_scale=1e0)\n",
- "\n",
- "# Create loss function\n",
- "loss = nn.MSELoss()\n",
- "\n",
- "# Construct optimizer\n",
- "lr, mom = 0.01, 0.9\n",
- "optimizer = optim.SGD(model.parameters(), lr=lr, momentum=mom)\n",
- "# optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))\n",
- "\n",
- "# Train the model\n",
- "# TEST_MODE: Set to reduce epochs for automated testing\n",
- "import os\n",
- "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n",
- "epochs = 2 if TEST_MODE else 1000001\n",
- "loss_history, accuracy_history, param_history = train(model, dataloader, loss, optimizer, epochs=epochs, verbose_interval=max(1, epochs//100))"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "eae371c4-1405-4ac5-982c-0ebacb688ed7",
- "metadata": {},
- "source": [
- "## AGF Numerics"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "489e82e1-61c8-43e6-b260-fd96c815dec8",
- "metadata": {},
- "outputs": [],
- "source": [
- "class ModsumSubNetwork(nn.Module):\n",
- " \n",
- " def __init__(self, d_in, d_out, init_scale):\n",
- " super().__init__()\n",
- " assert d_in%2 == 0\n",
- " self.p = d_in // 2\n",
- " self.u = nn.Linear(self.p, 1, bias=False)\n",
- " self.v = nn.Linear(self.p, 1, bias=False)\n",
- " self.w = nn.Linear(1, d_out, bias=False)\n",
- " with torch.no_grad():\n",
- " self.w.weight.mul_(init_scale)\n",
- " self.u.weight.mul_(init_scale)\n",
- " self.v.weight.mul_(init_scale)\n",
- " self.active = False\n",
- " self.util_acc = 0\n",
- " self.c_a = 1/self.get_norm() - 1\n",
- " \n",
- " self.normalize()\n",
- " \n",
- " def get_norm(self):\n",
- " sqnorm = lambda x: torch.linalg.norm(x.weight)**2\n",
- " norm = torch.sqrt(sqnorm(self.w) + sqnorm(self.u) + sqnorm(self.v))\n",
- " return norm\n",
- " \n",
- " def reinitialize(self, u, v, w):\n",
- " with torch.no_grad():\n",
- " self.u.weight.copy_(u)\n",
- " self.v.weight.copy_(v)\n",
- " self.w.weight.copy_(w)\n",
- " self.c_a = 1/self.get_norm() - 1\n",
- " \n",
- " def forward(self, x):\n",
- " x1 = x[:, :self.p]\n",
- " x2 = x[:, self.p:]\n",
- " return self.w((self.u(x1) + self.v(x2))**2)\n",
- " \n",
- " def normalize(self):\n",
- " norm = self.get_norm()\n",
- " with torch.no_grad():\n",
- " self.w.weight.div_(norm)\n",
- " self.u.weight.div_(norm)\n",
- " self.v.weight.div_(norm)\n",
- " \n",
- " def utility_step(self, x, residual, learning_rate):\n",
- " f_i = self(x)\n",
- " util = torch.einsum('nd,nd->n', f_i, residual).mean()\n",
- " self.util_acc += 3 * learning_rate * util.item()\n",
- " norm_th = 1/(1 + self.c_a - self.util_acc)\n",
- " \n",
- " util.backward()\n",
- " with torch.no_grad():\n",
- " self.w.weight += norm_th * learning_rate * self.w.weight.grad\n",
- " self.u.weight += norm_th * learning_rate * self.u.weight.grad\n",
- " self.v.weight += norm_th * learning_rate * self.v.weight.grad\n",
- " self.w.weight.grad.zero_()\n",
- " self.u.weight.grad.zero_()\n",
- " self.v.weight.grad.zero_()\n",
- " self.normalize()\n",
- "\n",
- "\n",
- "class ModsumNetwork(nn.Module):\n",
- " \n",
- " def __init__(self, d_in, d_out, init_scale, width=100):\n",
- " super().__init__()\n",
- " self.d_in = d_in\n",
- " self.d_out = d_out\n",
- " self.width = width\n",
- " neurons = [ModsumSubNetwork(d_in, d_out, init_scale) for _ in range(width)]\n",
- " self.neurons = nn.ModuleList(neurons)\n",
- " self.set_mode(\"utilmax\")\n",
- " \n",
- " def load_init(self, U, V, W):\n",
- " for i, n in enumerate(self.neurons):\n",
- " u, v, w = U[i], V[i], W[i][:, None]\n",
- " n.reinitialize(u, v, w)\n",
- "\n",
- " def dormant(self):\n",
- " return [neuron for neuron in self.neurons if not neuron.active]\n",
- " \n",
- " def active(self):\n",
- " return [neuron for neuron in self.neurons if neuron.active]\n",
- "\n",
- " \n",
- " def set_mode(self, mode):\n",
- " if mode not in [\"utilmax\", \"costmin\"]:\n",
- " raise ValueError(\"mode must be utilmax or costmin\")\n",
- " self.mode = mode\n",
- " for neuron in self.neurons:\n",
- " grad_on = (mode==\"utilmax\") ^ neuron.active\n",
- " for param in neuron.parameters():\n",
- " param.requires_grad = grad_on\n",
- " \n",
- " def forward(self, x):\n",
- " if not np.any([n.active for n in self.neurons]):\n",
- " return torch.zeros(x.shape[0], self.d_out)\n",
- " else:\n",
- " outputs = torch.stack([neuron(x) for neuron in self.neurons if neuron.active], dim=0)\n",
- " return torch.sum(outputs, dim=0)\n",
- "\n",
- "\n",
- "def train_agf(X_train, Y_train, init_sz=1e-3, agf_steps=5, from_init=None, \n",
- " utilmax_lr=1, costmin_lr=1, costmin_maxiter=1e4, loss_thresh=1e-4):\n",
- " \n",
- " # Initialize\n",
- " d_in, d_out = X_train.shape[-1], Y_train.shape[-1]\n",
- " if from_init:\n",
- " U, V, W = from_init[\"U\"], from_init[\"V\"], from_init[\"W\"]\n",
- " assert d_in == U.shape[1]*2\n",
- " assert d_out == W.shape[1]\n",
- " width = U.shape[0]\n",
- " net = ModsumNetwork(d_in, d_out, init_sz, width=width)#.cuda()\n",
- " net.load_init(U, V, W)\n",
- " else:\n",
- " net = ModsumNetwork(d_in, d_out, init_sz, width=agf_steps)#.cuda()\n",
- " X_train.requires_grad = False\n",
- " \n",
- " def update_results(results, t):\n",
- " results[\"t\"].append(t)\n",
- " residual = (Y_train - net(X_train))\n",
- " residual = residual.detach()\n",
- " results[\"residuals\"].append(residual)\n",
- " loss = (residual**2).mean().item()\n",
- " results[\"losses\"].append(loss)\n",
- " results[\"models\"].append(net.state_dict())\n",
- " results[\"pred\"].append(net(X_train).detach().cpu().clone())\n",
- " \n",
- " results = {\n",
- " \"t\": [],\n",
- " \"residuals\": [],\n",
- " \"losses\": [],\n",
- " \"models\": [],\n",
- " \"pred\": [],\n",
- " }\n",
- "\n",
- " t = 0\n",
- " update_results(results, t)\n",
- " for _ in tqdm(range(agf_steps)):\n",
- " \n",
- " # Utility Maximization\n",
- " residual = (1/d_out) * 2*(Y_train - net(X_train))\n",
- " residual = residual.detach()\n",
- " iters = 0\n",
- " mode = \"utilmax\"\n",
- " while mode == \"utilmax\":\n",
- " for n in net.neurons:\n",
- " if n.active:\n",
- " continue\n",
- " n.utility_step(X_train, residual, utilmax_lr)\n",
- " if n.util_acc > n.c_a:\n",
- " n.active = True\n",
- " mode = \"costmin\"\n",
- " # break\n",
- " iters += 1\n",
- " net.set_mode(mode)\n",
- " t += iters\n",
- "\n",
- " # Cost Minimization\n",
- " optimizer = torch.optim.SGD(net.parameters(), lr=costmin_lr, momentum=0.9)\n",
- " for i in range(int(costmin_maxiter)):\n",
- " optimizer.zero_grad(set_to_none=False)\n",
- " residual = Y_train - net(X_train)\n",
- " loss = (residual ** 2).mean()\n",
- " loss.backward()\n",
- " optimizer.step()\n",
- " net.set_mode(\"utilmax\")\n",
- "\n",
- " \n",
- " print(f\"loss: {loss.item():.5f}\")\n",
- " update_results(results, t)\n",
- "\n",
- " # Check for Termination\n",
- " if not net.dormant() or loss.item() < loss_thresh:\n",
- " break\n",
- " \n",
- " return results"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5b3eddcc-8c3e-45dd-8f57-45585a021f5d",
- "metadata": {},
- "outputs": [],
- "source": [
- "costmin_lr = 0.01\n",
- "utilmax_lr = 0.1\n",
- "results = train_agf(X_tensor, Y_tensor, init_sz=model.init_scale, agf_steps=50, from_init=param_history[0],\n",
- " utilmax_lr=utilmax_lr, costmin_lr=costmin_lr,\n",
- " costmin_maxiter=1e4, loss_thresh=1e-4)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0f48aebc-a439-405a-a057-3f5c24cca91a",
- "metadata": {},
- "source": [
- "## Plot Loss"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "ff46febe-abb5-459a-bb06-a18a26afb967",
- "metadata": {},
- "outputs": [],
- "source": [
- "fig, ax = plt.subplots(1, 1, figsize=(6, 6))\n",
- "ax.plot(list(loss_history), lw=4)\n",
- "\n",
- "for lossval in results[\"losses\"]:\n",
- " ax.axhline(lossval, alpha=0.3, ls=\":\", color=\"xkcd:slate\", zorder=-4, lw=2)\n",
- "\n",
- "f = utilmax_lr / (lr/(1-mom))\n",
- "for t in results[\"t\"]:\n",
- " ax.axvline(f*t, alpha=0.3, ls=\":\", color=\"xkcd:slate\", zorder=-4, lw=2)\n",
- "\n",
- "times = results[\"t\"] + [epochs]\n",
- "AGF_losses = results[\"losses\"] + [results[\"losses\"][-1]]\n",
- "ax.step(f*np.array(times), AGF_losses, where=\"post\", lw=2, ls='dashed', color=\"k\")\n",
- "\n",
- "# === Compute power spectrum of template ===\n",
- "freq, power = get_power(template)\n",
- "valid = power > 1e-20\n",
- "freq, power = freq[valid], power[valid]\n",
- "sorted_idx = np.argsort(-power)\n",
- "freq, power = freq[sorted_idx], power[sorted_idx]\n",
- "\n",
- "alpha_values = [np.sum(power[k:]) for k in range(len(power))]\n",
- "coef = 1 / p\n",
- "for k, alpha in enumerate(alpha_values):\n",
- " ax.axhline(y=coef * alpha, color='black', linestyle='--', linewidth=2, zorder=-2)\n",
- "\n",
- "ax.set_xscale(\"log\")\n",
- "ax.set_yscale(\"log\")\n",
- "ax.set_xlim(1e1, 1e6)\n",
- "ax.set_ylim(1e-3, 1e0)\n",
- "ax.set_xlabel('Epochs', fontsize=24)\n",
- "ax.set_ylabel('Train Loss', fontsize=24)\n",
- "\n",
- "style_axes(ax)\n",
- "plt.grid(False)\n",
- "plt.tight_layout()\n",
- "plt.savefig(\"loss-without-lines.pdf\", bbox_inches=\"tight\")\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "40b851e7-6256-43cd-b9f3-aca38db04917",
- "metadata": {},
- "source": [
- "## Power Spectrum of output"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "68b25ca9-6339-49dd-9d45-577a51798a25",
- "metadata": {},
- "outputs": [],
- "source": [
- "# === SETTINGS ===\n",
- "p = Y_tensor.shape[1]\n",
- "num_freqs = p // 2 + 1\n",
- "\n",
- "# Compute template power spectrum\n",
- "template_ft = np.fft.rfft(template)\n",
- "template_power = np.abs(template_ft)[:num_freqs]\n",
- "\n",
- "# === Compute power spectrum of template ===\n",
- "freq, power = get_power(template)\n",
- "valid = power > 1e-20\n",
- "freq, power = freq[valid], power[valid]\n",
- "sorted_idx = np.argsort(-power)\n",
- "freq, power = freq[sorted_idx], power[sorted_idx]\n",
- "\n",
- "# === Theory lines ===\n",
- "alpha_values = [np.sum(power[k:]) for k in range(len(power))]\n",
- "coef = 1 / p\n",
- "theta0 = np.sqrt(2) * model.init_scale\n",
- "uMax = [np.sqrt(2 * p / 27) * (p * power[k] / 2)**(3/2) / p**2 for k in range(len(power))]\n",
- "tau_values = [(1 / theta0 - 1) / (3 * uMax[k]) for k in range(len(uMax))]\n",
- "step_size = 2 * coef * lr / (1 - mom)\n",
- "\n",
- "\n",
- "# Color settings\n",
- "cmap = plt.colormaps.get_cmap('tab20').resampled(num_freqs)\n",
- "manual_colors = {\n",
- " 0: 'tab:blue',\n",
- " 1: 'tab:orange',\n",
- " 2: 'tab:red',\n",
- " 3: 'tab:green',\n",
- " 4: 'tab:brown',\n",
- " 5: 'tab:purple',\n",
- "}\n",
- "colors = [manual_colors.get(i, cmap(i)) for i in range(num_freqs)]\n",
- "\n",
- "# Compute output power over time (GD)\n",
- "num_points = 1000\n",
- "steps = np.unique(np.logspace(0, np.log10(len(param_history) - 1), num_points, dtype=int))\n",
- "powers_over_time = []\n",
- "\n",
- "for step in steps:\n",
- " model.load_state_dict(param_history[step])\n",
- " model.eval()\n",
- " with torch.no_grad():\n",
- " outputs = model(X_tensor)\n",
- " ft = np.fft.rfft(outputs.detach().cpu().numpy(), axis=1)\n",
- " avg_power = np.mean(np.abs(ft), axis=0)\n",
- " powers_over_time.append(avg_power)\n",
- "\n",
- "powers_over_time = np.array(powers_over_time) # shape: (steps, freqs)\n",
- "\n",
- "\n",
- "# Compute output power over time (AGF)\n",
- "f = utilmax_lr / (lr/(1-mom))\n",
- "AGF_steps = results[\"t\"]\n",
- "powers_over_time_AGF = []\n",
- "for i, step in enumerate(AGF_steps):\n",
- " outputs = results[\"pred\"][i]\n",
- " ft = np.fft.rfft(outputs.detach().cpu().numpy(), axis=1)\n",
- " avg_power = np.mean(np.abs(ft), axis=0)\n",
- " powers_over_time_AGF.append(avg_power)\n",
- "powers_over_time_AGF = np.array(powers_over_time_AGF) # shape: (steps, freqs)\n",
- "AGF_steps = [f * t for t in AGF_steps]\n",
- "\n",
- "AGF_steps.append(epochs)\n",
- "powers_over_time_AGF = np.vstack([\n",
- " powers_over_time_AGF,\n",
- " powers_over_time_AGF[-1, :]\n",
- "])\n",
- "\n",
- "# === PLOTTING ===\n",
- "fig, ax = plt.subplots(figsize=(6, 7))\n",
- "\n",
- "for k in range(num_freqs):\n",
- " color = colors[k]\n",
- " label = fr\"$\\xi = {k}$\" if k in [1, 3, 5] else None\n",
- " ax.plot(steps, powers_over_time[:, k], color=color, lw=3, label=label)\n",
- " label_agf = 'AGF' if k == 10 else None\n",
- " ax.step(AGF_steps, powers_over_time_AGF[:, k], color='k', lw=2, ls='dashed', where=\"post\", label=label_agf)\n",
- " ax.axhline(template_power[k], color=color, linestyle='dotted', linewidth=2, alpha=0.5, zorder=-10)\n",
- "\n",
- "for k, tau in enumerate(tau_values):\n",
- " color = colors[freq[k]]\n",
- " ax.axvline(x=tau / step_size, color=color, linestyle='dashed', linewidth=2, alpha=0.5)\n",
- "\n",
- " # Add arrow at intersection\n",
- " x = tau / step_size\n",
- " y = template_power[freq[k]]\n",
- " #draw an arrow from the lower bound to the right\n",
- " #use default color cycle\n",
- " ax.arrow(1.04 * x, y + 0.5, 1.5 * x, 0, \n",
- " head_width=0.2, head_length=x*0.2, length_includes_head=True,\n",
- " fc=color, ec=color, lw=4)\n",
- "\n",
- "# # Add vertical lines if needed\n",
- "# for step in time_steps:\n",
- "# ax.axvline(x=step, color='gray', alpha=0.5, linestyle='solid', linewidth=2)\n",
- "\n",
- "# Labeling and formatting\n",
- "ax.set_xscale('log')\n",
- "ax.set_xlim(5e1, 2e6)\n",
- "ax.set_xticks([1000, 10000, 100000, epochs-1])\n",
- "ax.set_ylabel(\"Power\", fontsize=24)\n",
- "ax.set_xlabel(\"Epochs\", fontsize=24)\n",
- "ax.legend(fontsize=14, title=\"Frequency\", title_fontsize=16, loc='upper right', bbox_to_anchor=(1, 0.9), labelspacing=0.25)\n",
- "\n",
- "style_axes(ax)\n",
- "ax.set_xticks([1000, 10000, 100000, epochs-1])\n",
- "ax.grid(False)\n",
- "plt.tight_layout()\n",
- "plt.savefig(\"fourier_power_only.pdf\", bbox_inches=\"tight\")\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5ef2c971-d9f1-41e6-b8eb-4e467496ccfd",
- "metadata": {},
- "source": [
- "## Plot outputs"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e333d1ab-1501-434f-86d2-82c10bb58f11",
- "metadata": {},
- "outputs": [],
- "source": [
- "import matplotlib.pyplot as plt\n",
- "\n",
- "# Choose time steps to visualize\n",
- "steps_to_show = [1000, 10000, 100000, epochs-1]\n",
- "num_samples = 1 # how many examples to plot per row\n",
- "p = Y_tensor.shape[1]\n",
- "x = np.arange(p)\n",
- "\n",
- "fig, axes = plt.subplots(len(steps_to_show), 1, figsize=(6, 6), sharex=True)\n",
- "\n",
- "for row, step in enumerate(steps_to_show):\n",
- " # Load weights at this step\n",
- " model.load_state_dict(param_history[step])\n",
- " model.eval()\n",
- "\n",
- " indices = np.random.choice(len(Y_tensor), size=num_samples, replace=False)\n",
- " with torch.no_grad():\n",
- " preds = model(X_tensor[indices]).detach().cpu().numpy()\n",
- " truths = Y_tensor[indices].detach().cpu().numpy()\n",
- "\n",
- " ax = axes[row]\n",
- " for i, idx in enumerate(indices):\n",
- " a = idx // p\n",
- " b = idx % p\n",
- " label_true = r\"$(a + b) \\cdot x$\"\n",
- " label_pred = r\"$f(a \\cdot x, b \\cdot x)$\"\n",
- "\n",
- " # Plot ground truth\n",
- " interpolate(ax, truths[i], color=f\"C{i}\", alpha=0.9, continuous=True)\n",
- " ax.scatter(x, truths[i], color=f\"C{i}\", s=30, alpha=0.9, label=label_true)\n",
- "\n",
- " # Plot prediction\n",
- " interpolate(ax, preds[i], color='k', alpha=1.0, continuous=True)\n",
- " ax.scatter(x, preds[i], color='k', s=30, alpha=0.7, label=label_pred)\n",
- "\n",
- " style_axes(ax, numyticks=3, labelsize=12)\n",
- " ax.grid(False)\n",
- " ax.set_ylabel(fr\"$t = 10^{{{int(np.log10(step))}}}$\", fontsize=20)\n",
- "\n",
- " # Only bottom row gets x-ticks\n",
- " if row < len(steps_to_show) - 1:\n",
- " ax.tick_params(labelbottom=False)\n",
- "\n",
- " # ax.legend(loc='best', fontsize=12, title=fr\"$a = {a}, b = {b}$\", handlelength=0, labelspacing=0.1, title_fontsize=14, frameon=False)\n",
- " ax.legend(\n",
- " loc='center left',\n",
- " bbox_to_anchor=(0.95, 0.5), # X slightly beyond the right edge, Y centered\n",
- " fontsize=8,\n",
- " title=fr\"$a = {a}, b = {b}$\",\n",
- " title_fontsize=10,\n",
- " handlelength=0,\n",
- " labelspacing=0.1,\n",
- " frameon=False\n",
- " )\n",
- "\n",
- "# axes[-1].set_xlabel(\"Output Index\", fontsize=20)\n",
- "plt.tight_layout()\n",
- "plt.savefig(\"predictions.pdf\", bbox_inches='tight')"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "b267424b-a0e5-47e3-9e01-1dc41e05e026",
- "metadata": {},
- "source": [
- "## Plot Weights"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9de707bf-838e-4384-8150-3d8fe4586fc3",
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "import matplotlib.cm as cm\n",
- "import matplotlib.gridspec as gridspec\n",
- "\n",
- "# Steps and corresponding highlighted frequencies\n",
- "\n",
- "steps = [1000, 10000, 100000, epochs-1]\n",
- "highlight_freqs_list = [[], [1], [3], [5]]\n",
- "\n",
- "num_rows, num_cols = len(steps), 3\n",
- "\n",
- "# Use gridspec to control layout\n",
- "fig = plt.figure(figsize=(24, 6), constrained_layout=True)\n",
- "gs = gridspec.GridSpec(num_rows, num_cols, width_ratios=[1.1, 1.1, 2.0], wspace=0.1, hspace=0.1)\n",
- "axes = np.empty((num_rows, num_cols), dtype=object)\n",
- "\n",
- "# Create axes\n",
- "for row in range(num_rows):\n",
- " for col in range(num_cols):\n",
- " if col == 2:\n",
- " ax = fig.add_subplot(gs[row, col], projection='polar')\n",
- " else:\n",
- " ax = fig.add_subplot(gs[row, col]) # ⬅ no sharex anymore\n",
- " axes[row, col] = ax\n",
- "\n",
- "num_freqs = None\n",
- "for row, index in enumerate(steps):\n",
- " highlight_freqs = highlight_freqs_list[row]\n",
- " params = param_history[index]\n",
- " W = params['W'].numpy()\n",
- " h, p = W.shape\n",
- "\n",
- " if num_freqs is None:\n",
- " num_freqs = p // 2 + 1\n",
- " cmap = plt.colormaps.get_cmap('tab20').resampled(num_freqs)\n",
- " colors = [cmap(i) for i in range(num_freqs)]\n",
- " manual_colors = {\n",
- " 0: 'tab:blue',\n",
- " 1: 'tab:orange',\n",
- " 2: 'tab:red',\n",
- " 3: 'tab:green',\n",
- " 4: 'tab:brown',\n",
- " 5: 'tab:purple',\n",
- " }\n",
- " freq_colors = [manual_colors.get(i, cmap(i)) for i in range(num_freqs)]\n",
- "\n",
- "\n",
- " # === Column 1: Weights ===\n",
- " ax = axes[row, 0]\n",
- " for i in range(h):\n",
- " w = W[i, :]\n",
- " ft = np.fft.rfft(w)\n",
- " power = np.abs(ft)**2\n",
- " dom_idx = np.argmax(power)\n",
- " color = freq_colors[dom_idx]\n",
- " alpha = 0.9 if not highlight_freqs or dom_idx in highlight_freqs else 0.1\n",
- " x = np.linspace(0, p - 1, 500)\n",
- " interpolate(ax, w, color=color, continuous=True, alpha=alpha)\n",
- " ax.scatter(np.arange(p), w, color=color, s=10, alpha=alpha)\n",
- " if row == 0: ax.set_title(\"Weights\", fontsize=24)\n",
- " ax.set_ylabel(fr\"$t = 10^{{{int(np.log10(index))}}}$\", fontsize=20)\n",
- " style_axes(ax, numyticks=3, numxticks=5, labelsize=12)\n",
- " ax.grid(False)\n",
- " if row < num_rows - 1:\n",
- " ax.tick_params(labelbottom=False)\n",
- "\n",
- " # === Column 2: Frequency Spectrum ===\n",
- " ax = axes[row, 1]\n",
- " for i in range(h):\n",
- " w = W[i, :]\n",
- " ft = np.fft.rfft(w)\n",
- " power = np.abs(ft)**2\n",
- " for k in range(len(power)):\n",
- " color = freq_colors[k]\n",
- " ax.vlines(k, 0, power[k], linewidth=4, color=color, alpha=0.4)\n",
- " ax.scatter(k, power[k], color=color, s=50, alpha=0.7)\n",
- " # ax.axhline(0, color='gray', linewidth=1, linestyle='--', alpha=0.4)\n",
- " ax.set_xlim(-0.5, len(power) - 0.5)\n",
- " ax.set_xticks(np.arange(len(power)))\n",
- " if row == 0: ax.set_title(\"Frequency\", fontsize=24)\n",
- " style_axes(ax, numyticks=3, numxticks=11, labelsize=12)\n",
- " ax.grid(False)\n",
- " if row < num_rows - 1:\n",
- " ax.tick_params(labelbottom=False)\n",
- "\n",
- " # === Column 3: Phase Polar Plot ===\n",
- " ax = axes[row, 2]\n",
- " for i in range(h):\n",
- " w = W[i, :]\n",
- " ft = np.fft.rfft(w)\n",
- " power = np.abs(ft)**2\n",
- " dom_idx = np.argmax(power)\n",
- " phase = np.angle(ft[dom_idx])\n",
- " norm = np.linalg.norm(w)\n",
- " color = freq_colors[dom_idx]\n",
- " alpha = 0.9 if not highlight_freqs or dom_idx in highlight_freqs else 0.1\n",
- " ax.plot([phase, phase], [0, norm], color=color, linewidth=2, alpha=alpha)\n",
- " ax.scatter(phase, norm, color=color, s=40, alpha=alpha)\n",
- " angles = np.arange(0, 360, 45)\n",
- " # ax.set_thetagrids(angles, [f\"{a}°\" if a in [45,135,225,315] else \"\" for a in angles])\n",
- " ax.set_thetagrids(angles, [\"\" for a in angles])\n",
- " ax.set_yticklabels([])\n",
- " ax.spines['polar'].set_linewidth(2)\n",
- " if row == 0: ax.set_title(\"Phase\", fontsize=24)\n",
- "\n",
- "# Shift polar plots left to reduce whitespace\n",
- "for row in range(num_rows):\n",
- " ax = axes[row, 2]\n",
- " pos = ax.get_position()\n",
- " ax.set_position([pos.x0 - 0.155, pos.y0, pos.width, pos.height])\n",
- "\n",
- "plt.savefig(\"W-weights.pdf\", bbox_inches='tight')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9f5f56f9-7055-4056-9a18-7d91b3be50f8",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "rubiks",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.11"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/notebooks/paper_figures.ipynb b/notebooks/paper_figures.ipynb
deleted file mode 100644
index 696c450..0000000
--- a/notebooks/paper_figures.ipynb
+++ /dev/null
@@ -1,268 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "249dbef8",
- "metadata": {},
- "source": [
- "# Create Figures for Group AGF Paper"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "58a60e4e",
- "metadata": {},
- "source": [
- "For each group, we should have 4 figures that make up a panel. We can organize these figures in Adobe Illustrator for the final figures.\n",
- "- Loss plot with clear drops\n",
- "- Model output power over epoch\n",
- "- Weights over epoch (ideally one weight per irreps, but for MNIST just choose ~5.)\n",
- "- Output prediction vs target over epoch"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "3fd72af4",
- "metadata": {},
- "source": [
- "Action Plan:\n",
- "- Load Model\n",
- "- recreate power and loss figures\n",
- "\n",
- "Weights figure\n",
- "- go through weights and label them with their predomonant frequency\n",
- "- Choose 1 weight per frequency, and plot these over time. In AI outline these weights with the same color as their power\n",
- "- in the loss plot, try to make the color of the line (starting at the drop) match the color of the power and the irreps\n",
- "\n",
- "Output figure\n",
- "- detect jumps in power\n",
- "- right after each jump, save epoch value for label\n",
- "- at that epoch, plot Y[i] and model[X[i]] output."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "7ff57238",
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import random\n",
- "import torch\n",
- "import os\n",
- "import torch.nn as nn\n",
- "import torch.optim as optim\n",
- "import shutil\n",
- "from torch.utils.data import DataLoader, TensorDataset\n",
- "import matplotlib.pyplot as plt\n",
- "import matplotlib.cm as cm\n",
- "from matplotlib.animation import FuncAnimation\n",
- "from matplotlib.ticker import FormatStrFormatter\n",
- "from matplotlib.ticker import FuncFormatter\n",
- "from matplotlib.ticker import MaxNLocator\n",
- "\n",
- "import importlib\n",
- "import pickle\n",
- "\n",
- "import group_agf.binary_action_learning.models as models\n",
- "import group_agf.binary_action_learning.datasets as datasets\n",
- "import group_agf.binary_action_learning.power as power\n",
- "import group_agf.binary_action_learning.plot as plot\n",
- "\n",
- "from escnn.group import *"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d9fef4df",
- "metadata": {},
- "outputs": [],
- "source": [
- "# parameters\n",
- "config = {\n",
- " \"model_save_dir\": \"/tmp/adele/\",\n",
- " \"dataset_fraction\": 0.3, # fraction of the total dataset to train on\n",
- " \"group_name\": 'dihedral', # 'dihedral', 'cnxcn', 'octahedral'\n",
- " \"mnist_digit\": 4,\n",
- " \"group_n\": 4, # n in Dn [3, 4, 5]\n",
- " \"image_length\": 10, # length of one side of the square image patch\n",
- " # Learning Parameters\n",
- " \"seed\": 10,\n",
- " \"init_scale\": 1e-2,\n",
- " \"lr\": 0.001,\n",
- " \"mom\": 0.9,\n",
- " # Training parameters\n",
- " \"epochs\": 5000,\n",
- " \"verbose_interval\": 100,\n",
- " \"batch_size\": 128,\n",
- " \"frequencies_to_learn\": 3,\n",
- " \"run_start_time\": \"10-31_14-42-46\",\n",
- "}\n",
- "\n",
- "config[\"group\"] = DihedralGroup(config[\"group_n\"])\n",
- "config[\"group_size\"] = config[\"group\"].order()\n",
- "\n",
- "print(f\"Group name: {config['group_name']}, group size: {config['group_size']}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "aa742e82",
- "metadata": {},
- "outputs": [],
- "source": [
- "from pyexpat import model\n",
- "\n",
- "\n",
- "model_save_path = models.get_model_save_path(config)\n",
- "print(model_save_path)\n",
- "# model_save_path = '/tmp/adele/model_group_namedihedral_group_size8_frac0.3_init0.01_lr0.01_mom0.9_bs128_epochs5000_seed10_run_start10-31_14-25-32.pkl'\n",
- "# print(model_save_path)\n",
- "\n",
- "# Load training history and model parameters from the specified path\n",
- "with open(model_save_path, 'rb') as f:\n",
- " training_history = pickle.load(f)\n",
- "\n",
- "loss_history = training_history[\"loss_history\"]\n",
- "accuracy_history = training_history[\"accuracy_history\"]\n",
- "param_history = training_history[\"param_history\"]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "14393dea",
- "metadata": {},
- "outputs": [],
- "source": [
- "X, Y, template = datasets.load_dataset(config)\n",
- "template_power = power.GroupPower(template, config['group'])\n",
- "X, Y, device = datasets.move_dataset_to_device_and_flatten(X, Y, device=None)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "78975e86",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Create the model instance (matching the one that was trained)\n",
- "model = models.TwoLayerNet(\n",
- " group_size=config['group_size'],\n",
- " hidden_size=None, # uses default if not provided\n",
- " nonlinearity='square', # adjust if needed according to training\n",
- " init_scale=config['init_scale'],\n",
- " output_scale=1.0\n",
- ").to(device)\n",
- "\n",
- "# Restore the model parameters from the last saved point in param_history\n",
- "final_params = param_history[-1]\n",
- "with torch.no_grad():\n",
- " model.U.copy_(final_params['U'])\n",
- " model.V.copy_(final_params['V'])\n",
- " model.W.copy_(final_params['W'])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a1723c01",
- "metadata": {},
- "outputs": [],
- "source": [
- "power_over_training_plot, freq_colors = plot.plot_training_power_over_time(\n",
- " template_power, \n",
- " model, \n",
- " device, \n",
- " param_history, \n",
- " X, \n",
- " config['group_name'], \n",
- " save_path=None, \n",
- " show=False,\n",
- " return_freq_colors=True\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "89459141",
- "metadata": {},
- "outputs": [],
- "source": [
- "loss_plot = plot.plot_loss_curve(loss_history, template_power, show=False, freq_colors=freq_colors)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "86b2dd11",
- "metadata": {},
- "outputs": [],
- "source": [
- "irreps_plot = plot.plot_irreps(config['group'], config['group_name'], show=False)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9d8be8b8",
- "metadata": {},
- "outputs": [],
- "source": [
- "neuron_indices = list(range(config['group_size'] * config['frequencies_to_learn']))\n",
- "neuron_weights_plot = plot.plot_neuron_weights(\n",
- " config['group_name'],\n",
- " config['group'],\n",
- " model,\n",
- " config['group_size'],\n",
- " neuron_indices=neuron_indices,\n",
- " show=True\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "31eec909",
- "metadata": {},
- "outputs": [],
- "source": [
- "model_predictions_plot = plot.plot_model_outputs(config['group_name'], config['group_size'], model, X, Y, idx=13, step=1, show=True) \n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "3a877866",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "group-agf",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.12"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/notebooks/rnn_gagf.ipynb b/notebooks/rnn_gagf.ipynb
deleted file mode 100644
index fb9f402..0000000
--- a/notebooks/rnn_gagf.ipynb
+++ /dev/null
@@ -1,2112 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "af291059",
- "metadata": {},
- "outputs": [],
- "source": [
- "import os\n",
- "import yaml\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "from pathlib import Path\n",
- "import seaborn as sns\n",
- "\n",
- "def load_sweep_results_grid(sweep_dir: str, k_values: list, hidden_dims: list):\n",
- " \"\"\"\n",
- " Load sweep results and organize into a grid for heatmap visualization.\n",
- " \n",
- " Args:\n",
- " sweep_dir: Path to the sweep directory\n",
- " k_values: List of k (sequence length) values\n",
- " hidden_dims: List of hidden dimension values\n",
- " \n",
- " Returns:\n",
- " grid: 2D numpy array with shape (len(hidden_dims), len(k_values))\n",
- " containing mean final train losses\n",
- " std_grid: 2D numpy array with standard deviations (if multiple seeds)\n",
- " \"\"\"\n",
- " sweep_path = Path(sweep_dir)\n",
- " \n",
- " # Initialize grids\n",
- " grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n",
- " std_grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n",
- " \n",
- " # Load results for each experiment\n",
- " for i, h in enumerate(hidden_dims):\n",
- " for j, k in enumerate(k_values):\n",
- " exp_name = f\"k{k}_h{h}\"\n",
- " exp_dir = sweep_path / exp_name\n",
- " \n",
- " if not exp_dir.exists():\n",
- " print(f\"Warning: Experiment {exp_name} not found\")\n",
- " continue\n",
- " \n",
- " # Load experiment summary\n",
- " summary_file = exp_dir / \"experiment_summary.yaml\"\n",
- " if summary_file.exists():\n",
- " with open(summary_file, 'r') as f:\n",
- " summary = yaml.safe_load(f)\n",
- " \n",
- " # Get mean train loss\n",
- " if 'train_loss_stats' in summary:\n",
- " grid[i, j] = summary['train_loss_stats']['mean']\n",
- " std_grid[i, j] = summary['train_loss_stats']['std']\n",
- " else:\n",
- " print(f\"Warning: No train_loss_stats in {exp_name}\")\n",
- " else:\n",
- " print(f\"Warning: No summary file for {exp_name}\")\n",
- " \n",
- " return grid, std_grid\n"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "c433cb4d",
- "metadata": {},
- "source": [
- "## 1D Analysis Functions\n",
- "\n",
- "Analyze individual 1D experiments from the sweep with detailed power spectrum and neuron specialization plots.\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "95df6861",
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch\n",
- "import numpy as np\n",
- "from pathlib import Path\n",
- "from gagf.rnns.utils import (\n",
- " plot_prediction_power_spectrum_over_time_1d,\n",
- " plot_wout_neuron_specialization_1d,\n",
- " plot_model_predictions_over_time_1d,\n",
- " topk_template_freqs_1d,\n",
- ")\n",
- "from gagf.rnns.model import SequentialMLP\n",
- "\n",
- "def analyze_1d_experiment(sweep_dir, exp_name, seed=0, num_freqs_to_track=10):\n",
- " \"\"\"\n",
- " Analyze a single 1D experiment from the sweep.\n",
- " \n",
- " Args:\n",
- " sweep_dir: Path to sweep directory\n",
- " exp_name: Experiment name (e.g., \"k3_h360\")\n",
- " seed: Seed number to analyze\n",
- " num_freqs_to_track: Number of top frequencies to track\n",
- " \n",
- " Returns:\n",
- " Dictionary with analysis results\n",
- " \"\"\"\n",
- " # Setup paths\n",
- " sweep_path = Path(sweep_dir)\n",
- " exp_dir = sweep_path / exp_name / f\"seed_{seed}\"\n",
- " \n",
- " if not exp_dir.exists():\n",
- " print(f\"Experiment directory not found: {exp_dir}\")\n",
- " return None\n",
- " \n",
- " print(f\"Analyzing: {exp_name}, seed {seed}\")\n",
- " print(f\"Directory: {exp_dir}\")\n",
- " \n",
- " # Load config\n",
- " with open(exp_dir / \"config.yaml\", 'r') as f:\n",
- " config = yaml.safe_load(f)\n",
- " \n",
- " # Load template\n",
- " template = np.load(exp_dir / \"template.npy\")\n",
- " p = len(template)\n",
- " k = config['data']['k']\n",
- " hidden_dim = config['model']['hidden_dim']\n",
- " \n",
- " print(f\" p={p}, k={k}, hidden_dim={hidden_dim}\")\n",
- " \n",
- " # Load training history\n",
- " train_loss_hist = np.load(exp_dir / \"train_loss_history.npy\")\n",
- " param_hist = torch.load(exp_dir / \"param_history.pt\", map_location='cpu')\n",
- " \n",
- " # Create model\n",
- " device = 'cpu'\n",
- " template_torch = torch.tensor(template, dtype=torch.float32, device=device)\n",
- " model = SequentialMLP(\n",
- " p=p,\n",
- " d=hidden_dim,\n",
- " template=template_torch,\n",
- " k=k,\n",
- " init_scale=config['model']['init_scale'],\n",
- " return_all_outputs=config['model']['return_all_outputs'],\n",
- " ).to(device)\n",
- " \n",
- " # Generate evaluation data\n",
- " from gagf.rnns.datamodule import build_modular_addition_sequence_dataset_1d\n",
- " X_data, Y_data, _ = build_modular_addition_sequence_dataset_1d(\n",
- " p, template, k,\n",
- " mode='sampled',\n",
- " num_samples=1000,\n",
- " return_all_outputs=config['model']['return_all_outputs'],\n",
- " )\n",
- " X_data_t = torch.tensor(X_data, dtype=torch.float32, device=device)\n",
- " Y_data_t = torch.tensor(Y_data, dtype=torch.float32, device=device)\n",
- " \n",
- " # Get tracked frequencies\n",
- " tracked_freqs = topk_template_freqs_1d(template, K=num_freqs_to_track)\n",
- " colors = plt.cm.tab10(np.linspace(0, 1, len(tracked_freqs)))\n",
- " \n",
- " # Checkpoints to analyze\n",
- " checkpoint_indices = [0, len(param_hist)//4, len(param_hist)//2, \n",
- " 3*len(param_hist)//4, len(param_hist)-1]\n",
- " \n",
- " # Plot 1: Power spectrum over time\n",
- " print(\"\\\\n Plotting power spectrum analysis...\")\n",
- " fig1, _, _, _ = plot_prediction_power_spectrum_over_time_1d(\n",
- " model, param_hist, X_data_t, Y_data_t, template, p,\n",
- " loss_history=train_loss_hist,\n",
- " num_freqs_to_track=num_freqs_to_track,\n",
- " num_samples=100,\n",
- " save_path=exp_dir / \"power_spectrum_analysis_1d.pdf\",\n",
- " show=True\n",
- " )\n",
- " \n",
- " # Plot 2: Model predictions over time\n",
- " print(\" Plotting predictions over time...\")\n",
- " fig2, _ = plot_model_predictions_over_time_1d(\n",
- " model, param_hist, X_data_t, Y_data_t, p,\n",
- " steps=checkpoint_indices,\n",
- " save_path=exp_dir / \"predictions_over_time_1d.pdf\",\n",
- " show=True\n",
- " )\n",
- " \n",
- " # Plot 3: W_out neuron specialization\n",
- " print(\" Plotting W_out neuron specialization...\")\n",
- " figs3 = plot_wout_neuron_specialization_1d(\n",
- " param_hist, tracked_freqs, colors, p,\n",
- " steps=checkpoint_indices,\n",
- " dead_thresh_l2=0.25,\n",
- " save_dir=exp_dir,\n",
- " show=True\n",
- " )\n",
- " \n",
- " print(\"\\\\n ✓ Analysis complete!\")\n",
- " \n",
- " return {\n",
- " 'config': config,\n",
- " 'template': template,\n",
- " 'train_loss': train_loss_hist,\n",
- " 'tracked_freqs': tracked_freqs,\n",
- " }\n",
- "\n",
- "# Example usage:\n",
- "# sweep_dir = \"/home/facosta/group-agf/sweeps/sweep_mlp_scaling_20251202_XXXXXX\"\n",
- "# result = analyze_1d_experiment(sweep_dir, \"k3_h360\", seed=0)\n"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "7bc3db90",
- "metadata": {},
- "source": [
- "# Analyze RNNs trained on GAGF sequential task"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "11fb7c9b",
- "metadata": {},
- "source": [
- "## Set up"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "43581ce4",
- "metadata": {},
- "outputs": [],
- "source": [
- "# autoreload\n",
- "%load_ext autoreload\n",
- "%autoreload 2\n",
- "# jupyter black formatter\n",
- "%load_ext jupyter_black\n",
- "\n",
- "import subprocess\n",
- "import os\n",
- "import sys\n",
- "\n",
- "gitroot_path = subprocess.check_output(\n",
- " [\"git\", \"rev-parse\", \"--show-toplevel\"], universal_newlines=True\n",
- ").strip()\n",
- "\n",
- "os.chdir(gitroot_path)\n",
- "print(\"Working directory: \", os.getcwd())\n",
- "\n",
- "if gitroot_path not in sys.path:\n",
- " sys.path.insert(0, gitroot_path)\n",
- "print(\"Directory added to path: \", gitroot_path)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "f0407a17",
- "metadata": {},
- "source": [
- "## Sequence-to-sequence sweep across different values of k (sequence length)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "070e8c55",
- "metadata": {},
- "source": [
- "### Loss curves"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "cf050abb",
- "metadata": {},
- "outputs": [],
- "source": [
- "import os\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "from pathlib import Path\n",
- "from typing import Dict, List, Optional\n",
- "\n",
- "\n",
- "def get_sweep_experiments(sweep_dir: str) -> List[str]:\n",
- " \"\"\"\n",
- " Get all experiment names from a sweep directory.\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to sweep directory\n",
- "\n",
- " Returns:\n",
- " List of experiment names (subdirectories with seed_0)\n",
- " \"\"\"\n",
- " sweep_path = Path(sweep_dir)\n",
- " experiments = []\n",
- "\n",
- " for item in sweep_path.iterdir():\n",
- " if (\n",
- " item.is_dir()\n",
- " and not item.name.startswith(\".\")\n",
- " and item.name not in [\"configs\"]\n",
- " ):\n",
- " # Check if it has a seed_0 subdirectory\n",
- " if (item / \"seed_0\").exists():\n",
- " experiments.append(item.name)\n",
- "\n",
- " return sorted(experiments)\n",
- "\n",
- "\n",
- "def load_experiment_losses(\n",
- " sweep_dir: str, experiment_name: str, seed: int = 0\n",
- ") -> Dict[str, np.ndarray]:\n",
- " \"\"\"\n",
- " Load training and validation loss histories for an experiment.\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to sweep directory\n",
- " experiment_name: Name of the experiment subdirectory\n",
- " seed: Seed number (default: 0)\n",
- "\n",
- " Returns:\n",
- " Dictionary with 'train' and 'val' loss arrays (if they exist)\n",
- " \"\"\"\n",
- " exp_path = Path(sweep_dir) / experiment_name / f\"seed_{seed}\"\n",
- " losses = {}\n",
- "\n",
- " train_loss_path = exp_path / \"train_loss_history.npy\"\n",
- " if train_loss_path.exists():\n",
- " losses[\"train\"] = np.load(train_loss_path)\n",
- "\n",
- " val_loss_path = exp_path / \"val_loss_history.npy\"\n",
- " if val_loss_path.exists():\n",
- " losses[\"val\"] = np.load(val_loss_path)\n",
- "\n",
- " return losses\n",
- "\n",
- "\n",
- "def load_all_sweep_losses(\n",
- " sweep_dir: str, seed: int = 0\n",
- ") -> Dict[str, Dict[str, np.ndarray]]:\n",
- " \"\"\"\n",
- " Load loss histories for all experiments in a sweep.\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to sweep directory\n",
- " seed: Seed number (default: 0)\n",
- "\n",
- " Returns:\n",
- " Dictionary mapping experiment names to their loss dictionaries\n",
- " \"\"\"\n",
- " experiments = get_sweep_experiments(sweep_dir)\n",
- " all_losses = {}\n",
- "\n",
- " for exp_name in experiments:\n",
- " all_losses[exp_name] = load_experiment_losses(sweep_dir, exp_name, seed)\n",
- "\n",
- " return all_losses\n",
- "\n",
- "\n",
- "def remove_outliers_local(loss_history, window=10, threshold=3.0):\n",
- " \"\"\"\n",
- " Replace outliers with local median if they deviate too much.\n",
- "\n",
- " Args:\n",
- " loss_history: Array of loss values\n",
- " window: Window size for local statistics\n",
- " threshold: How many local standard deviations to consider an outlier\n",
- "\n",
- " Returns:\n",
- " Tuple of (cleaned loss history, whether any outliers were found)\n",
- " \"\"\"\n",
- " cleaned = loss_history.copy()\n",
- " half_window = window // 2\n",
- " outliers_found = False\n",
- "\n",
- " for i in range(len(loss_history)):\n",
- " start = max(0, i - half_window)\n",
- " end = min(len(loss_history), i + half_window + 1)\n",
- " local_window = loss_history[start:end]\n",
- "\n",
- " local_median = np.median(local_window)\n",
- " local_std = np.std(local_window)\n",
- "\n",
- " # If the value is too far from local median, replace it\n",
- " if abs(loss_history[i] - local_median) > threshold * local_std:\n",
- " cleaned[i] = local_median\n",
- " outliers_found = True\n",
- "\n",
- " return cleaned, outliers_found"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6b5653c8",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Set up your sweep directory\n",
- "sweep_dir = \"/home/facosta/group-agf/sweeps/seq_seq_sweep_20251113_120513\"\n",
- "\n",
- "# Get all experiments in the sweep\n",
- "experiments = get_sweep_experiments(sweep_dir)\n",
- "print(f\"Found experiments: {experiments}\")\n",
- "\n",
- "# Load losses for all experiments\n",
- "all_losses = load_all_sweep_losses(sweep_dir)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "fa6b246a",
- "metadata": {},
- "outputs": [],
- "source": [
- "def plot_loss_comparison(\n",
- " sweep_dir: str,\n",
- " experiments: Optional[List[str]] = None,\n",
- " loss_type: str = \"train\",\n",
- " log_scale: bool = True,\n",
- " figsize: tuple = (10, 6),\n",
- " seed: int = 0,\n",
- " remove_outliers: bool = False,\n",
- " outlier_window: int = 10,\n",
- " outlier_threshold: float = 3.0,\n",
- " template_2d: Optional[np.ndarray] = None,\n",
- " p1: Optional[int] = None,\n",
- " p2: Optional[int] = None,\n",
- " show_theory_bands: bool = True,\n",
- " num_theory_lines: Optional[int] = None,\n",
- " color_by_k: bool = True,\n",
- " cmap: str = \"viridis\",\n",
- "):\n",
- " \"\"\"\n",
- " Plot and compare loss curves from multiple experiments.\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to sweep directory\n",
- " experiments: List of experiment names to plot (None = all experiments)\n",
- " loss_type: 'train' or 'val'\n",
- " log_scale: Whether to use log scale for both axes\n",
- " figsize: Figure size tuple\n",
- " seed: Seed number (default: 0)\n",
- " remove_outliers: Whether to remove outliers using local outlier replacement\n",
- " outlier_window: Window size for outlier detection (default: 10)\n",
- " outlier_threshold: Threshold in standard deviations for outlier detection (default: 3.0)\n",
- " template_2d: Optional 2D template array for computing theory lines\n",
- " p1: First dimension of template (required if template_2d is provided)\n",
- " p2: Second dimension of template (required if template_2d is provided)\n",
- " show_theory_bands: Whether to show colored bands between theory lines (default: True)\n",
- " num_theory_lines: Number of theory lines to show (default: None = show all)\n",
- " color_by_k: Whether to color lines by k value (default: True)\n",
- " cmap: Colormap name for k-based coloring (default: 'viridis')\n",
- " \"\"\"\n",
- " if experiments is None:\n",
- " experiments = get_sweep_experiments(sweep_dir)\n",
- "\n",
- " fig, ax = plt.subplots(figsize=figsize)\n",
- "\n",
- " # Compute theory lines if template is provided\n",
- " theory_levels = None\n",
- " if template_2d is not None:\n",
- " if p1 is None or p2 is None:\n",
- " raise ValueError(\"p1 and p2 must be provided if template_2d is given\")\n",
- "\n",
- " # Import the helper function (assuming it's in utils.py)\n",
- " from gagf.rnns.utils import get_power_2d_adele\n",
- "\n",
- " # Compute power spectrum of template\n",
- " _, _, power = get_power_2d_adele(template_2d)\n",
- " power_flat = np.sort(power.flatten()[power.flatten() > 1e-20])[::-1]\n",
- "\n",
- " # Theory levels (cumulative tail sums)\n",
- " alpha_values = np.array(\n",
- " [np.sum(power_flat[k:]) for k in range(len(power_flat))]\n",
- " )\n",
- " coef = 1.0 / (p1 * p2)\n",
- " theory_levels = coef * alpha_values # strictly decreasing\n",
- "\n",
- " # Limit number of lines if specified\n",
- " if num_theory_lines is not None:\n",
- " theory_levels = theory_levels[: num_theory_lines + 1]\n",
- "\n",
- " # Generate colors for bands\n",
- " n_bands = len(theory_levels) - 1\n",
- " colors = plt.cm.tab10(np.linspace(0, 1, max(n_bands, 1)))\n",
- "\n",
- " # Draw colored bands between theory lines\n",
- " if show_theory_bands and n_bands > 0:\n",
- " for i in range(n_bands):\n",
- " y_top = theory_levels[i]\n",
- " y_bot = theory_levels[i + 1]\n",
- " ax.axhspan(\n",
- " y_bot,\n",
- " y_top,\n",
- " facecolor=colors[i % len(colors)],\n",
- " alpha=0.15,\n",
- " zorder=-3,\n",
- " )\n",
- "\n",
- " # Draw the black theory lines\n",
- " for y in theory_levels:\n",
- " ax.axhline(\n",
- " y=y,\n",
- " color=\"black\",\n",
- " linestyle=\"--\",\n",
- " linewidth=1.5,\n",
- " alpha=0.7,\n",
- " zorder=-2,\n",
- " label=\"_nolegend_\",\n",
- " )\n",
- "\n",
- " # Extract k values and set up colormap\n",
- " k_values = {}\n",
- " if color_by_k:\n",
- " import re\n",
- "\n",
- " for exp_name in experiments:\n",
- " # Try to extract k value from experiment name (e.g., \"k_2_seqseq\" -> 2)\n",
- " match = re.search(r\"k[_\\s]*(\\d+)\", exp_name, re.IGNORECASE)\n",
- " if match:\n",
- " k_values[exp_name] = int(match.group(1))\n",
- "\n",
- " if k_values:\n",
- " k_min = min(k_values.values())\n",
- " k_max = max(k_values.values())\n",
- " norm = plt.cm.colors.Normalize(vmin=k_min, vmax=k_max)\n",
- " colormap = plt.cm.get_cmap(cmap)\n",
- " scalar_map = plt.cm.ScalarMappable(norm=norm, cmap=colormap)\n",
- "\n",
- " # Plot loss curves\n",
- " for exp_name in experiments:\n",
- " losses = load_experiment_losses(sweep_dir, exp_name, seed)\n",
- " if loss_type in losses:\n",
- " loss_history = losses[loss_type]\n",
- "\n",
- " # Apply outlier removal if requested\n",
- " if remove_outliers:\n",
- " loss_history, outliers_found = remove_outliers_local(\n",
- " loss_history, window=outlier_window, threshold=outlier_threshold\n",
- " )\n",
- " if outliers_found:\n",
- " print(f\"Outliers from {exp_name} removed for plot\")\n",
- "\n",
- " # Determine color\n",
- " if color_by_k and exp_name in k_values:\n",
- " color = scalar_map.to_rgba(k_values[exp_name])\n",
- " else:\n",
- " color = None # Use default color cycle\n",
- "\n",
- " ax.plot(loss_history, label=exp_name, alpha=0.8, linewidth=2, color=color)\n",
- "\n",
- " ax.set_xlabel(\"Step\", fontsize=14)\n",
- " ax.set_ylabel(f\"{loss_type.capitalize()} Loss\", fontsize=14)\n",
- " title = f\"{loss_type.capitalize()} Loss Comparison - {Path(sweep_dir).name}\"\n",
- " if remove_outliers:\n",
- " title += \" (outliers removed)\"\n",
- " ax.set_title(title, fontsize=14)\n",
- " ax.legend(fontsize=10)\n",
- " ax.grid(True, alpha=0.3)\n",
- "\n",
- " if log_scale:\n",
- " ax.set_xscale(\"log\")\n",
- " ax.set_yscale(\"log\")\n",
- "\n",
- " # Add colorbar if coloring by k\n",
- " if color_by_k and k_values:\n",
- " cbar = plt.colorbar(scalar_map, ax=ax, label=\"k (sequence length)\", pad=0.02)\n",
- " cbar.ax.tick_params(labelsize=10)\n",
- "\n",
- " plt.tight_layout()\n",
- " plt.show()\n",
- "\n",
- " return fig, ax, theory_levels"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a8e3f5e3",
- "metadata": {},
- "outputs": [],
- "source": [
- "template_path = os.path.join(sweep_dir, \"k_2_seqseq\", \"seed_0\", \"template.npy\")\n",
- "template_2d = np.load(template_path)\n",
- "p1, p2 = template_2d.shape\n",
- "\n",
- "plt.imshow(template_2d)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "dd4bf1a8",
- "metadata": {},
- "outputs": [],
- "source": [
- "# First, load the template from one of your experiments\n",
- "# (assuming they all use the same template)\n",
- "template_path = os.path.join(sweep_dir, \"k_2_seqseq\", \"seed_0\", \"template.npy\")\n",
- "template_2d = np.load(template_path)\n",
- "p1, p2 = template_2d.shape\n",
- "\n",
- "# Plot with theory lines\n",
- "fig, ax, theory_levels = plot_loss_comparison(\n",
- " sweep_dir,\n",
- " template_2d=template_2d,\n",
- " p1=p1,\n",
- " p2=p2,\n",
- " remove_outliers=True,\n",
- " num_theory_lines=10, # Show first 10 theory lines\n",
- " log_scale=True,\n",
- " cmap=\"viridis\",\n",
- ")\n",
- "\n",
- "# Print the theory levels (plateau values)\n",
- "print(\"Theory plateau levels:\")\n",
- "for i, level in enumerate(theory_levels):\n",
- " print(f\" Plateau {i}: {level:.6e}\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "ead6933c",
- "metadata": {},
- "source": [
- "### Time to reach plateau"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4e1d02b4",
- "metadata": {},
- "outputs": [],
- "source": [
- "def calculate_time_to_plateau(\n",
- " sweep_dir: str,\n",
- " template_2d: np.ndarray,\n",
- " p1: int,\n",
- " p2: int,\n",
- " target_plateau_idx: int = 1,\n",
- " loss_type: str = \"train\",\n",
- " seed: int = 0,\n",
- " tolerance: float = 1.1,\n",
- " experiments: Optional[List[str]] = None,\n",
- ") -> Dict[str, int]:\n",
- " \"\"\"\n",
- " Calculate the step at which each experiment's loss reaches a target plateau.\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to sweep directory\n",
- " template_2d: 2D template array for computing theory lines\n",
- " p1, p2: Dimensions of template\n",
- " target_plateau_idx: Index of target plateau (1 = first drop, 2 = second drop, etc.)\n",
- " loss_type: 'train' or 'val'\n",
- " seed: Seed number\n",
- " tolerance: Multiplier for plateau threshold (loss must be <= tolerance * plateau_level)\n",
- " experiments: List of experiment names (None = all experiments)\n",
- "\n",
- " Returns:\n",
- " Dictionary mapping experiment names to step numbers\n",
- " \"\"\"\n",
- " from gagf.rnns.utils import get_power_2d_adele\n",
- "\n",
- " # Compute theory levels\n",
- " _, _, power = get_power_2d_adele(template_2d)\n",
- " power_flat = np.sort(power.flatten()[power.flatten() > 1e-20])[::-1]\n",
- " alpha_values = np.array([np.sum(power_flat[k:]) for k in range(len(power_flat))])\n",
- " coef = 1.0 / (p1 * p2)\n",
- " theory_levels = coef * alpha_values\n",
- "\n",
- " if target_plateau_idx >= len(theory_levels):\n",
- " raise ValueError(\n",
- " f\"target_plateau_idx {target_plateau_idx} exceeds available plateaus ({len(theory_levels)})\"\n",
- " )\n",
- "\n",
- " target_level = theory_levels[target_plateau_idx]\n",
- " threshold = tolerance * target_level\n",
- "\n",
- " if experiments is None:\n",
- " experiments = get_sweep_experiments(sweep_dir)\n",
- "\n",
- " times_to_plateau = {}\n",
- "\n",
- " for exp_name in experiments:\n",
- " losses = load_experiment_losses(sweep_dir, exp_name, seed)\n",
- " if loss_type in losses:\n",
- " loss_history = losses[loss_type]\n",
- "\n",
- " # Find first step where loss drops below threshold\n",
- " crossing_indices = np.where(loss_history <= threshold)[0]\n",
- "\n",
- " if len(crossing_indices) > 0:\n",
- " times_to_plateau[exp_name] = crossing_indices[0]\n",
- " else:\n",
- " times_to_plateau[exp_name] = None # Never reached\n",
- "\n",
- " return times_to_plateau, theory_levels\n",
- "\n",
- "\n",
- "def plot_time_to_plateau(\n",
- " sweep_dir: str,\n",
- " template_2d: np.ndarray,\n",
- " p1: int,\n",
- " p2: int,\n",
- " target_plateau_idx: int = 1,\n",
- " loss_type: str = \"train\",\n",
- " seed: int = 0,\n",
- " tolerance: float = 1.1,\n",
- " experiments: Optional[List[str]] = None,\n",
- " figsize: tuple = (10, 6),\n",
- " sort_by: str = \"time\",\n",
- "):\n",
- " \"\"\"\n",
- " Plot the time to reach a target plateau for different experiments.\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to sweep directory\n",
- " template_2d: 2D template array\n",
- " p1, p2: Template dimensions\n",
- " target_plateau_idx: Which plateau to measure (1 = first drop)\n",
- " loss_type: 'train' or 'val'\n",
- " seed: Seed number\n",
- " tolerance: Multiplier for plateau threshold\n",
- " experiments: List of experiment names (None = all)\n",
- " figsize: Figure size\n",
- " sort_by: 'time' (sort by time to plateau) or 'name' (alphabetical)\n",
- " \"\"\"\n",
- " times, theory_levels = calculate_time_to_plateau(\n",
- " sweep_dir,\n",
- " template_2d,\n",
- " p1,\n",
- " p2,\n",
- " target_plateau_idx,\n",
- " loss_type,\n",
- " seed,\n",
- " tolerance,\n",
- " experiments,\n",
- " )\n",
- "\n",
- " # Filter out experiments that never reached the plateau\n",
- " reached = {k: v for k, v in times.items() if v is not None}\n",
- " not_reached = [k for k, v in times.items() if v is None]\n",
- "\n",
- " if not reached:\n",
- " print(\"No experiments reached the target plateau!\")\n",
- " return\n",
- "\n",
- " # Sort experiments\n",
- " if sort_by == \"time\":\n",
- " sorted_items = sorted(reached.items(), key=lambda x: x[1])\n",
- " else: # alphabetical\n",
- " sorted_items = sorted(reached.items(), key=lambda x: x[0])\n",
- "\n",
- " exp_names = [item[0] for item in sorted_items]\n",
- " steps = [item[1] for item in sorted_items]\n",
- "\n",
- " # Create bar plot\n",
- " fig, ax = plt.subplots(figsize=figsize)\n",
- " bars = ax.bar(range(len(exp_names)), steps, alpha=0.7, edgecolor=\"black\")\n",
- "\n",
- " # Color bars by value (gradient)\n",
- " colors = plt.cm.viridis(np.linspace(0.2, 0.9, len(steps)))\n",
- " for bar, color in zip(bars, colors):\n",
- " bar.set_color(color)\n",
- "\n",
- " ax.set_xticks(range(len(exp_names)))\n",
- " ax.set_xticklabels(exp_names, rotation=45, ha=\"right\")\n",
- " ax.set_xlabel(\"Experiment\", fontsize=12)\n",
- " ax.set_ylabel(\"Steps to Reach Plateau\", fontsize=12)\n",
- " ax.set_title(\n",
- " f\"Time to Reach Plateau {target_plateau_idx} (Level: {theory_levels[target_plateau_idx]:.2e})\",\n",
- " fontsize=13,\n",
- " )\n",
- " ax.grid(True, alpha=0.3, axis=\"y\")\n",
- "\n",
- " # Add value labels on bars\n",
- " for i, (bar, step) in enumerate(zip(bars, steps)):\n",
- " height = bar.get_height()\n",
- " ax.text(\n",
- " bar.get_x() + bar.get_width() / 2.0,\n",
- " height,\n",
- " f\"{int(step):,}\",\n",
- " ha=\"center\",\n",
- " va=\"bottom\",\n",
- " fontsize=9,\n",
- " )\n",
- "\n",
- " plt.tight_layout()\n",
- "\n",
- " # Print summary\n",
- " print(\n",
- " f\"\\nTime to reach plateau {target_plateau_idx} (threshold: {theory_levels[target_plateau_idx]:.2e}):\"\n",
- " )\n",
- " print(\"-\" * 60)\n",
- " for name, step in sorted_items:\n",
- " print(f\" {name:30s}: {step:8,} steps\")\n",
- "\n",
- " if not_reached:\n",
- " print(f\"\\nExperiments that did not reach plateau {target_plateau_idx}:\")\n",
- " for name in not_reached:\n",
- " print(f\" - {name}\")\n",
- "\n",
- " plt.show()\n",
- "\n",
- " return fig, ax, times, theory_levels"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b5b266d4",
- "metadata": {},
- "outputs": [],
- "source": [
- "def plot_time_to_plateau(\n",
- " sweep_dir: str,\n",
- " template_2d: np.ndarray,\n",
- " p1: int,\n",
- " p2: int,\n",
- " target_plateau_idx: int = 1,\n",
- " loss_type: str = \"train\",\n",
- " seed: int = 0,\n",
- " tolerance: float = 1.1,\n",
- " experiments: Optional[List[str]] = None,\n",
- " figsize: tuple = (10, 6),\n",
- " sort_by: str = \"time\",\n",
- " color_by_k: bool = True,\n",
- " cmap: str = \"viridis\",\n",
- "):\n",
- " \"\"\"\n",
- " Plot the time to reach a target plateau for different experiments.\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to sweep directory\n",
- " template_2d: 2D template array\n",
- " p1, p2: Template dimensions\n",
- " target_plateau_idx: Which plateau to measure (1 = first drop)\n",
- " loss_type: 'train' or 'val'\n",
- " seed: Seed number\n",
- " tolerance: Multiplier for plateau threshold\n",
- " experiments: List of experiment names (None = all)\n",
- " figsize: Figure size\n",
- " sort_by: 'time' (sort by time to plateau) or 'name' (alphabetical)\n",
- " color_by_k: Whether to color bars by k value (default: True)\n",
- " cmap: Colormap name for k-based coloring (default: 'viridis')\n",
- " \"\"\"\n",
- " times, theory_levels = calculate_time_to_plateau(\n",
- " sweep_dir,\n",
- " template_2d,\n",
- " p1,\n",
- " p2,\n",
- " target_plateau_idx,\n",
- " loss_type,\n",
- " seed,\n",
- " tolerance,\n",
- " experiments,\n",
- " )\n",
- "\n",
- " # Filter out experiments that never reached the plateau\n",
- " reached = {k: v for k, v in times.items() if v is not None}\n",
- " not_reached = [k for k, v in times.items() if v is None]\n",
- "\n",
- " if not reached:\n",
- " print(\"No experiments reached the target plateau!\")\n",
- " return\n",
- "\n",
- " # Sort experiments\n",
- " if sort_by == \"time\":\n",
- " sorted_items = sorted(reached.items(), key=lambda x: x[1])\n",
- " else: # alphabetical\n",
- " sorted_items = sorted(reached.items(), key=lambda x: x[0])\n",
- "\n",
- " exp_names = [item[0] for item in sorted_items]\n",
- " steps = [item[1] for item in sorted_items]\n",
- "\n",
- " # Extract k values and set up colormap\n",
- " k_values = []\n",
- " if color_by_k:\n",
- " import re\n",
- "\n",
- " for exp_name in exp_names:\n",
- " # Try to extract k value from experiment name (e.g., \"k_2_seqseq\" -> 2)\n",
- " match = re.search(r\"k[_\\s]*(\\d+)\", exp_name, re.IGNORECASE)\n",
- " if match:\n",
- " k_values.append(int(match.group(1)))\n",
- " else:\n",
- " k_values.append(None)\n",
- "\n",
- " # Check if we have valid k values\n",
- " valid_k_values = [k for k in k_values if k is not None]\n",
- " if valid_k_values:\n",
- " k_min = min(valid_k_values)\n",
- " k_max = max(valid_k_values)\n",
- " norm = plt.cm.colors.Normalize(vmin=k_min, vmax=k_max)\n",
- " colormap = plt.cm.get_cmap(cmap)\n",
- " scalar_map = plt.cm.ScalarMappable(norm=norm, cmap=colormap)\n",
- " else:\n",
- " color_by_k = False # Fall back if no k values found\n",
- "\n",
- " # Create bar plot\n",
- " fig, ax = plt.subplots(figsize=figsize)\n",
- " bars = ax.bar(range(len(exp_names)), steps, alpha=0.7, edgecolor=\"black\")\n",
- "\n",
- " # Color bars\n",
- " if color_by_k and valid_k_values:\n",
- " for bar, k_val in zip(bars, k_values):\n",
- " if k_val is not None:\n",
- " bar.set_color(scalar_map.to_rgba(k_val))\n",
- " else:\n",
- " bar.set_color(\"gray\") # Fallback color\n",
- " else:\n",
- " # Color bars by position (gradient from blue to yellow)\n",
- " colors = plt.cm.viridis(np.linspace(0.2, 0.9, len(steps)))\n",
- " for bar, color in zip(bars, colors):\n",
- " bar.set_color(color)\n",
- "\n",
- " ax.set_xticks(range(len(exp_names)))\n",
- " ax.set_xticklabels(exp_names, rotation=45, ha=\"right\")\n",
- " ax.set_xlabel(\"Experiment\", fontsize=12)\n",
- " ax.set_ylabel(\"Steps to Reach Plateau\", fontsize=12)\n",
- " ax.set_title(\n",
- " f\"Time to Reach Plateau {target_plateau_idx} (Level: {theory_levels[target_plateau_idx]:.2e})\",\n",
- " fontsize=13,\n",
- " )\n",
- " ax.grid(True, alpha=0.3, axis=\"y\")\n",
- "\n",
- " # Add value labels on bars\n",
- " for i, (bar, step) in enumerate(zip(bars, steps)):\n",
- " height = bar.get_height()\n",
- " ax.text(\n",
- " bar.get_x() + bar.get_width() / 2.0,\n",
- " height,\n",
- " f\"{int(step):,}\",\n",
- " ha=\"center\",\n",
- " va=\"bottom\",\n",
- " fontsize=9,\n",
- " )\n",
- "\n",
- " # Add colorbar if coloring by k\n",
- " if color_by_k and valid_k_values:\n",
- " cbar = plt.colorbar(scalar_map, ax=ax, label=\"k (sequence length)\", pad=0.02)\n",
- " cbar.ax.tick_params(labelsize=10)\n",
- "\n",
- " plt.tight_layout()\n",
- "\n",
- " # Print summary\n",
- " print(\n",
- " f\"\\nTime to reach plateau {target_plateau_idx} (threshold: {theory_levels[target_plateau_idx]:.2e}):\"\n",
- " )\n",
- " print(\"-\" * 60)\n",
- " for name, step in sorted_items:\n",
- " print(f\" {name:30s}: {step:8,} steps\")\n",
- "\n",
- " if not_reached:\n",
- " print(f\"\\nExperiments that did not reach plateau {target_plateau_idx}:\")\n",
- " for name in not_reached:\n",
- " print(f\" - {name}\")\n",
- "\n",
- " plt.show()\n",
- "\n",
- " return fig, ax, times, theory_levels"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b015323f",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Load template\n",
- "template_path = os.path.join(sweep_dir, \"k_2_seqseq\", \"seed_0\", \"template.npy\")\n",
- "template_2d = np.load(template_path)\n",
- "p1, p2 = template_2d.shape\n",
- "\n",
- "# get the times without plotting\n",
- "times_dict, theory_levels = calculate_time_to_plateau(\n",
- " sweep_dir, template_2d, p1, p2, target_plateau_idx=1\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "eb63ea4a",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Plot time to first drop (plateau index 1)\n",
- "fig, ax, times, levels = plot_time_to_plateau(\n",
- " sweep_dir,\n",
- " template_2d,\n",
- " p1,\n",
- " p2,\n",
- " target_plateau_idx=1, # First drop\n",
- " tolerance=1.1, # Loss must be <= 1.1 * theory_level\n",
- " sort_by=\"name\",\n",
- ")\n",
- "\n",
- "# Plot time to second drop (plateau index 2)\n",
- "plot_time_to_plateau(\n",
- " sweep_dir,\n",
- " template_2d,\n",
- " p1,\n",
- " p2,\n",
- " target_plateau_idx=2,\n",
- " tolerance=1.05, # Second drop\n",
- " cmap=\"viridis\",\n",
- " sort_by=\"name\",\n",
- ")\n",
- "\n",
- "# Plot time to third drop (plateau index 3)\n",
- "plot_time_to_plateau(\n",
- " sweep_dir,\n",
- " template_2d,\n",
- " p1,\n",
- " p2,\n",
- " target_plateau_idx=3,\n",
- " tolerance=1.05, # Second drop\n",
- " cmap=\"viridis\",\n",
- " sort_by=\"name\",\n",
- ")\n",
- "\n",
- "# Plot time to fourth drop (plateau index 4)\n",
- "plot_time_to_plateau(\n",
- " sweep_dir,\n",
- " template_2d,\n",
- " p1,\n",
- " p2,\n",
- " target_plateau_idx=4,\n",
- " tolerance=1.05, # Second drop\n",
- " cmap=\"viridis\",\n",
- " sort_by=\"name\",\n",
- ");"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5536cdb2",
- "metadata": {},
- "source": [
- "## Sequence-to-one sweep across different values of k (sequence length)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "063d9df4",
- "metadata": {},
- "source": [
- "- 1st sweep dir: \"/home/facosta/group-agf/sweeps/optim_sweep_20251113_001549\" (k=4, 5)\n",
- "- 2nd sweep dir: \"/home/facosta/group-agf/sweeps/optim_sweep_20251113_145528\" (k=2, 3)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d96f443e",
- "metadata": {},
- "outputs": [],
- "source": [
- "def load_losses_from_multiple_sweeps(\n",
- " sweep_experiments: Dict[str, List[str]], loss_type: str = \"train\", seed: int = 0\n",
- ") -> Dict[str, Dict[str, np.ndarray]]:\n",
- " \"\"\"\n",
- " Load loss histories from experiments across multiple sweeps.\n",
- "\n",
- " Args:\n",
- " sweep_experiments: Dictionary mapping sweep_dir -> list of experiment names\n",
- " e.g., {sweep_dir_1: [\"exp1\", \"exp2\"], sweep_dir_2: [\"exp3\", \"exp4\"]}\n",
- " loss_type: 'train' or 'val'\n",
- " seed: Seed number\n",
- "\n",
- " Returns:\n",
- " Dictionary mapping experiment names to their loss dictionaries\n",
- " Note: Experiment names must be unique across sweeps\n",
- " \"\"\"\n",
- " all_losses = {}\n",
- "\n",
- " for sweep_dir, exp_names in sweep_experiments.items():\n",
- " for exp_name in exp_names:\n",
- " losses = load_experiment_losses(sweep_dir, exp_name, seed)\n",
- " if loss_type in losses:\n",
- " # Store with experiment name as key\n",
- " all_losses[exp_name] = losses\n",
- " else:\n",
- " print(\n",
- " f\"Warning: No {loss_type} loss found for {exp_name} in {sweep_dir}\"\n",
- " )\n",
- "\n",
- " return all_losses\n",
- "\n",
- "\n",
- "def create_color_mapping_from_multiple_sweeps(\n",
- " sweep_experiments: Dict[str, List[str]],\n",
- " parameter_path: str,\n",
- " cmap: str = \"viridis\",\n",
- " seed: int = 0,\n",
- " log_scale: bool = False,\n",
- ") -> tuple:\n",
- " \"\"\"\n",
- " Create color mapping for experiments across multiple sweeps.\n",
- "\n",
- " Args:\n",
- " sweep_experiments: Dictionary mapping sweep_dir -> list of experiment names\n",
- " parameter_path: Dot-separated path to parameter (e.g., 'data.batch_size')\n",
- " cmap: Colormap name\n",
- " seed: Seed number\n",
- " log_scale: Whether to use logarithmic scale for color mapping\n",
- "\n",
- " Returns:\n",
- " Tuple of (color_mapping dict, scalar_map, param_values dict)\n",
- " \"\"\"\n",
- " param_values = {}\n",
- "\n",
- " # Extract parameters from all experiments across all sweeps\n",
- " for sweep_dir, exp_names in sweep_experiments.items():\n",
- " for exp_name in exp_names:\n",
- " value = extract_config_parameter(sweep_dir, exp_name, parameter_path, seed)\n",
- " if value is not None:\n",
- " param_values[exp_name] = value\n",
- "\n",
- " if not param_values:\n",
- " print(f\"Warning: Could not extract '{parameter_path}' from any experiments\")\n",
- " return {}, None, {}\n",
- "\n",
- " # Create color mapping\n",
- " values = list(param_values.values())\n",
- " v_min = min(values)\n",
- " v_max = max(values)\n",
- "\n",
- " # Use log or linear normalization\n",
- " if log_scale:\n",
- " if v_min <= 0:\n",
- " print(\n",
- " f\"Warning: log_scale requested but found non-positive values (min={v_min}). Using linear scale.\"\n",
- " )\n",
- " norm = plt.cm.colors.Normalize(vmin=v_min, vmax=v_max)\n",
- " else:\n",
- " norm = plt.cm.colors.LogNorm(vmin=v_min, vmax=v_max)\n",
- " else:\n",
- " norm = plt.cm.colors.Normalize(vmin=v_min, vmax=v_max)\n",
- "\n",
- " colormap = plt.cm.get_cmap(cmap)\n",
- "\n",
- " color_mapping = {}\n",
- " for exp_name, value in param_values.items():\n",
- " color_mapping[exp_name] = colormap(norm(value))\n",
- "\n",
- " scalar_map = plt.cm.ScalarMappable(norm=norm, cmap=colormap)\n",
- "\n",
- " return color_mapping, scalar_map, param_values\n",
- "\n",
- "\n",
- "def plot_loss_comparison_multi_sweep(\n",
- " sweep_experiments: Dict[str, List[str]],\n",
- " loss_type: str = \"train\",\n",
- " log_scale: bool = True,\n",
- " figsize: tuple = (10, 6),\n",
- " seed: int = 0,\n",
- " remove_outliers: bool = False,\n",
- " outlier_window: int = 10,\n",
- " outlier_threshold: float = 3.0,\n",
- " template_2d: Optional[np.ndarray] = None,\n",
- " p1: Optional[int] = None,\n",
- " p2: Optional[int] = None,\n",
- " show_theory_bands: bool = True,\n",
- " num_theory_lines: Optional[int] = None,\n",
- " color_mapping: Optional[Dict[str, tuple]] = None,\n",
- " colorbar_label: Optional[str] = None,\n",
- " scalar_map: Optional[plt.cm.ScalarMappable] = None,\n",
- "):\n",
- " \"\"\"\n",
- " Plot and compare loss curves from experiments across multiple sweeps.\n",
- "\n",
- " Args:\n",
- " sweep_experiments: Dictionary mapping sweep_dir -> list of experiment names\n",
- " e.g., {sweep_dir_1: [\"exp1\", \"exp2\"], sweep_dir_2: [\"exp3\"]}\n",
- " loss_type: 'train' or 'val'\n",
- " log_scale: Whether to use log scale for both axes\n",
- " figsize: Figure size tuple\n",
- " seed: Seed number\n",
- " remove_outliers: Whether to remove outliers\n",
- " outlier_window: Window size for outlier detection\n",
- " outlier_threshold: Threshold for outlier detection\n",
- " template_2d: Optional 2D template array for computing theory lines\n",
- " p1, p2: Template dimensions\n",
- " show_theory_bands: Whether to show colored bands between theory lines\n",
- " num_theory_lines: Number of theory lines to show\n",
- " color_mapping: Dictionary mapping experiment names to RGBA colors\n",
- " colorbar_label: Label for colorbar\n",
- " scalar_map: ScalarMappable for colorbar\n",
- " \"\"\"\n",
- " fig, ax = plt.subplots(figsize=figsize)\n",
- "\n",
- " # Compute theory lines if template is provided\n",
- " theory_levels = None\n",
- " if template_2d is not None:\n",
- " if p1 is None or p2 is None:\n",
- " raise ValueError(\"p1 and p2 must be provided if template_2d is given\")\n",
- "\n",
- " from gagf.rnns.utils import get_power_2d_adele\n",
- "\n",
- " _, _, power = get_power_2d_adele(template_2d)\n",
- " power_flat = np.sort(power.flatten()[power.flatten() > 1e-20])[::-1]\n",
- "\n",
- " alpha_values = np.array(\n",
- " [np.sum(power_flat[k:]) for k in range(len(power_flat))]\n",
- " )\n",
- " coef = 1.0 / (p1 * p2)\n",
- " theory_levels = coef * alpha_values\n",
- "\n",
- " if num_theory_lines is not None:\n",
- " theory_levels = theory_levels[: num_theory_lines + 1]\n",
- "\n",
- " n_bands = len(theory_levels) - 1\n",
- " colors = plt.cm.tab10(np.linspace(0, 1, max(n_bands, 1)))\n",
- "\n",
- " if show_theory_bands and n_bands > 0:\n",
- " for i in range(n_bands):\n",
- " y_top = theory_levels[i]\n",
- " y_bot = theory_levels[i + 1]\n",
- " ax.axhspan(\n",
- " y_bot,\n",
- " y_top,\n",
- " facecolor=colors[i % len(colors)],\n",
- " alpha=0.15,\n",
- " zorder=-3,\n",
- " )\n",
- "\n",
- " for y in theory_levels:\n",
- " ax.axhline(\n",
- " y=y,\n",
- " color=\"black\",\n",
- " linestyle=\"--\",\n",
- " linewidth=1.5,\n",
- " alpha=0.7,\n",
- " zorder=-2,\n",
- " label=\"_nolegend_\",\n",
- " )\n",
- "\n",
- " # Plot loss curves from all sweeps\n",
- " for sweep_dir, exp_names in sweep_experiments.items():\n",
- " for exp_name in exp_names:\n",
- " losses = load_experiment_losses(sweep_dir, exp_name, seed)\n",
- " if loss_type in losses:\n",
- " loss_history = losses[loss_type]\n",
- "\n",
- " if remove_outliers:\n",
- " loss_history, outliers_found = remove_outliers_local(\n",
- " loss_history, window=outlier_window, threshold=outlier_threshold\n",
- " )\n",
- " if outliers_found:\n",
- " print(f\"Outliers from {exp_name} removed for plot\")\n",
- "\n",
- " # Determine color\n",
- " if color_mapping and exp_name in color_mapping:\n",
- " color = color_mapping[exp_name]\n",
- " else:\n",
- " color = None\n",
- "\n",
- " ax.plot(\n",
- " loss_history, label=exp_name, alpha=0.8, linewidth=2, color=color\n",
- " )\n",
- "\n",
- " ax.set_xlabel(\"Step\", fontsize=14)\n",
- " ax.set_ylabel(f\"{loss_type.capitalize()} Loss\", fontsize=14)\n",
- " ax.set_title(\n",
- " f\"{loss_type.capitalize()} Loss Comparison (Multiple Sweeps)\", fontsize=14\n",
- " )\n",
- " ax.legend(fontsize=10)\n",
- " ax.grid(True, alpha=0.3)\n",
- "\n",
- " if log_scale:\n",
- " ax.set_xscale(\"log\")\n",
- " ax.set_yscale(\"log\")\n",
- "\n",
- " if color_mapping and scalar_map is not None:\n",
- " label = colorbar_label if colorbar_label else \"Parameter Value\"\n",
- " cbar = plt.colorbar(scalar_map, ax=ax, label=label, pad=0.02)\n",
- " cbar.ax.tick_params(labelsize=10)\n",
- "\n",
- " plt.tight_layout()\n",
- " plt.show()\n",
- "\n",
- " return fig, ax, theory_levels"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9ac8e706",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Define which experiments from which sweeps\n",
- "sweep_experiments = {\n",
- " \"/home/facosta/group-agf/sweeps/optim_sweep_20251113_001549\": [\n",
- " \"k_4_seqone\",\n",
- " \"k_5_seqone\",\n",
- " ],\n",
- " \"/home/facosta/group-agf/sweeps/optim_sweep_20251113_145528\": [\n",
- " \"k_2_seqone\",\n",
- " \"k_3_seqone\",\n",
- " ],\n",
- "}\n",
- "\n",
- "# Load template (assuming same template across all)\n",
- "template_path = os.path.join(\n",
- " list(sweep_experiments.keys())[0], # First sweep\n",
- " list(sweep_experiments.values())[0][0], # First experiment\n",
- " \"seed_0\",\n",
- " \"template.npy\",\n",
- ")\n",
- "template_2d = np.load(template_path)\n",
- "p1, p2 = template_2d.shape\n",
- "\n",
- "# Create color mapping based on parameter across all sweeps\n",
- "color_map, scalar_map, k_values = create_color_mapping_from_multiple_sweeps(\n",
- " sweep_experiments,\n",
- " \"data.k\", # or 'data.batch_size', 'training.learning_rate', etc.\n",
- " cmap=\"viridis\",\n",
- " log_scale=False,\n",
- ")\n",
- "\n",
- "# Plot\n",
- "plot_loss_comparison_multi_sweep(\n",
- " sweep_experiments,\n",
- " template_2d=template_2d,\n",
- " p1=p1,\n",
- " p2=p2,\n",
- " color_mapping=color_map,\n",
- " colorbar_label=\"k (sequence length)\",\n",
- " scalar_map=scalar_map,\n",
- " num_theory_lines=10,\n",
- " remove_outliers=False,\n",
- " log_scale=True,\n",
- ")\n",
- "\n",
- "# Print extracted values to verify\n",
- "print(\"Extracted k values:\", k_values)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "1019938c",
- "metadata": {},
- "outputs": [],
- "source": [
- "def calculate_time_to_plateau_multi_sweep(\n",
- " sweep_experiments: Dict[str, List[str]],\n",
- " template_2d: np.ndarray,\n",
- " p1: int,\n",
- " p2: int,\n",
- " target_plateau_idx: int = 1,\n",
- " loss_type: str = \"train\",\n",
- " seed: int = 0,\n",
- " tolerance: float = 1.1,\n",
- ") -> tuple:\n",
- " \"\"\"\n",
- " Calculate time to plateau for experiments across multiple sweeps.\n",
- "\n",
- " Args:\n",
- " sweep_experiments: Dictionary mapping sweep_dir -> list of experiment names\n",
- " template_2d: 2D template array\n",
- " p1, p2: Template dimensions\n",
- " target_plateau_idx: Which plateau to measure (1 = first drop)\n",
- " loss_type: 'train' or 'val'\n",
- " seed: Seed number\n",
- " tolerance: Multiplier for plateau threshold\n",
- "\n",
- " Returns:\n",
- " Tuple of (times_to_plateau dict, theory_levels array)\n",
- " \"\"\"\n",
- " from gagf.rnns.utils import get_power_2d_adele\n",
- "\n",
- " # Compute theory levels\n",
- " _, _, power = get_power_2d_adele(template_2d)\n",
- " power_flat = np.sort(power.flatten()[power.flatten() > 1e-20])[::-1]\n",
- " alpha_values = np.array([np.sum(power_flat[k:]) for k in range(len(power_flat))])\n",
- " coef = 1.0 / (p1 * p2)\n",
- " theory_levels = coef * alpha_values\n",
- "\n",
- " if target_plateau_idx >= len(theory_levels):\n",
- " raise ValueError(\n",
- " f\"target_plateau_idx {target_plateau_idx} exceeds available plateaus ({len(theory_levels)})\"\n",
- " )\n",
- "\n",
- " target_level = theory_levels[target_plateau_idx]\n",
- " threshold = tolerance * target_level\n",
- "\n",
- " times_to_plateau = {}\n",
- "\n",
- " # Process all experiments across all sweeps\n",
- " for sweep_dir, exp_names in sweep_experiments.items():\n",
- " for exp_name in exp_names:\n",
- " losses = load_experiment_losses(sweep_dir, exp_name, seed)\n",
- " if loss_type in losses:\n",
- " loss_history = losses[loss_type]\n",
- "\n",
- " # Find first step where loss drops below threshold\n",
- " crossing_indices = np.where(loss_history <= threshold)[0]\n",
- "\n",
- " if len(crossing_indices) > 0:\n",
- " times_to_plateau[exp_name] = crossing_indices[0]\n",
- " else:\n",
- " times_to_plateau[exp_name] = None # Never reached\n",
- "\n",
- " return times_to_plateau, theory_levels\n",
- "\n",
- "\n",
- "def plot_time_to_plateau_multi_sweep(\n",
- " sweep_experiments: Dict[str, List[str]],\n",
- " template_2d: np.ndarray,\n",
- " p1: int,\n",
- " p2: int,\n",
- " target_plateau_idx: int = 1,\n",
- " loss_type: str = \"train\",\n",
- " seed: int = 0,\n",
- " tolerance: float = 1.1,\n",
- " figsize: tuple = (10, 6),\n",
- " sort_by: str = \"time\",\n",
- " color_mapping: Optional[Dict[str, tuple]] = None,\n",
- " colorbar_label: Optional[str] = None,\n",
- " scalar_map: Optional[plt.cm.ScalarMappable] = None,\n",
- " show_not_reached: bool = True,\n",
- "):\n",
- " \"\"\"\n",
- " Plot time to reach a target plateau for experiments across multiple sweeps.\n",
- "\n",
- " Args:\n",
- " sweep_experiments: Dictionary mapping sweep_dir -> list of experiment names\n",
- " template_2d: 2D template array\n",
- " p1, p2: Template dimensions\n",
- " target_plateau_idx: Which plateau to measure (1 = first drop)\n",
- " loss_type: 'train' or 'val'\n",
- " seed: Seed number\n",
- " tolerance: Multiplier for plateau threshold\n",
- " figsize: Figure size\n",
- " sort_by: 'time' (sort by time to plateau) or 'name' (alphabetical)\n",
- " color_mapping: Dictionary mapping experiment names to RGBA colors\n",
- " colorbar_label: Label for colorbar\n",
- " scalar_map: ScalarMappable for colorbar\n",
- " show_not_reached: Whether to show experiments that didn't reach plateau\n",
- " \"\"\"\n",
- " times, theory_levels = calculate_time_to_plateau_multi_sweep(\n",
- " sweep_experiments,\n",
- " template_2d,\n",
- " p1,\n",
- " p2,\n",
- " target_plateau_idx,\n",
- " loss_type,\n",
- " seed,\n",
- " tolerance,\n",
- " )\n",
- "\n",
- " # Separate reached vs not reached\n",
- " reached = {k: v for k, v in times.items() if v is not None}\n",
- " not_reached = {k: v for k, v in times.items() if v is None}\n",
- "\n",
- " if not reached and not not_reached:\n",
- " print(\"No experiments found!\")\n",
- " return\n",
- "\n",
- " # Sort experiments that reached the plateau\n",
- " if reached:\n",
- " if sort_by == \"time\":\n",
- " sorted_reached = sorted(reached.items(), key=lambda x: x[1])\n",
- " else: # alphabetical\n",
- " sorted_reached = sorted(reached.items(), key=lambda x: x[0])\n",
- " else:\n",
- " sorted_reached = []\n",
- "\n",
- " # Add not-reached experiments at the end if requested\n",
- " if show_not_reached and not_reached:\n",
- " not_reached_sorted = sorted(not_reached.keys())\n",
- " exp_names = [item[0] for item in sorted_reached] + not_reached_sorted\n",
- " # Use a very large value for visualization (e.g., max reached time * 1.5)\n",
- " max_reached_time = max([v for v in reached.values()]) if reached else 10000\n",
- " placeholder_time = max_reached_time * 1.3\n",
- " steps = [item[1] for item in sorted_reached] + [placeholder_time] * len(\n",
- " not_reached_sorted\n",
- " )\n",
- " reached_mask = [True] * len(sorted_reached) + [False] * len(not_reached_sorted)\n",
- " else:\n",
- " exp_names = [item[0] for item in sorted_reached]\n",
- " steps = [item[1] for item in sorted_reached]\n",
- " reached_mask = [True] * len(sorted_reached)\n",
- "\n",
- " if not exp_names:\n",
- " print(\"No experiments to plot!\")\n",
- " return\n",
- "\n",
- " # Create bar plot\n",
- " fig, ax = plt.subplots(figsize=figsize)\n",
- " bars = ax.bar(range(len(exp_names)), steps, alpha=0.7, edgecolor=\"black\")\n",
- "\n",
- " # Color bars\n",
- " for i, (bar, exp_name, did_reach) in enumerate(zip(bars, exp_names, reached_mask)):\n",
- " if not did_reach:\n",
- " # Gray out experiments that didn't reach plateau\n",
- " bar.set_color(\"lightgray\")\n",
- " bar.set_alpha(0.4)\n",
- " bar.set_hatch(\"///\")\n",
- " elif color_mapping and exp_name in color_mapping:\n",
- " bar.set_color(color_mapping[exp_name])\n",
- " else:\n",
- " # Default gradient coloring\n",
- " colors = plt.cm.viridis(\n",
- " np.linspace(0.2, 0.9, len([m for m in reached_mask if m]))\n",
- " )\n",
- " bar.set_color(colors[i] if did_reach else \"lightgray\")\n",
- "\n",
- " ax.set_xticks(range(len(exp_names)))\n",
- " ax.set_xticklabels(exp_names, rotation=45, ha=\"right\")\n",
- " ax.set_xlabel(\"Experiment\", fontsize=12)\n",
- " ax.set_ylabel(\"Steps to Reach Plateau\", fontsize=12)\n",
- " ax.set_title(\n",
- " f\"Time to Reach Plateau {target_plateau_idx} (Level: {theory_levels[target_plateau_idx]:.2e})\",\n",
- " fontsize=13,\n",
- " )\n",
- " ax.grid(True, alpha=0.3, axis=\"y\")\n",
- "\n",
- " # Add value labels on bars\n",
- " for i, (bar, step, did_reach, exp_name) in enumerate(\n",
- " zip(bars, steps, reached_mask, exp_names)\n",
- " ):\n",
- " height = bar.get_height()\n",
- " if did_reach:\n",
- " ax.text(\n",
- " bar.get_x() + bar.get_width() / 2.0,\n",
- " height,\n",
- " f\"{int(step):,}\",\n",
- " ha=\"center\",\n",
- " va=\"bottom\",\n",
- " fontsize=9,\n",
- " )\n",
- " else:\n",
- " # Add \"Did not reach\" annotation\n",
- " ax.text(\n",
- " bar.get_x() + bar.get_width() / 2.0,\n",
- " height / 2,\n",
- " \"Did not\\nreach\",\n",
- " ha=\"center\",\n",
- " va=\"center\",\n",
- " fontsize=8,\n",
- " style=\"italic\",\n",
- " color=\"darkgray\",\n",
- " weight=\"bold\",\n",
- " )\n",
- "\n",
- " # Add colorbar if color mapping provided (only for experiments that reached)\n",
- " if color_mapping and scalar_map is not None and reached:\n",
- " label = colorbar_label if colorbar_label else \"Parameter Value\"\n",
- " cbar = plt.colorbar(scalar_map, ax=ax, label=label, pad=0.02)\n",
- " cbar.ax.tick_params(labelsize=10)\n",
- "\n",
- " plt.tight_layout()\n",
- "\n",
- " # Print summary\n",
- " print(\n",
- " f\"\\nTime to reach plateau {target_plateau_idx} (threshold: {theory_levels[target_plateau_idx]:.2e}):\"\n",
- " )\n",
- " print(\"-\" * 60)\n",
- " if reached:\n",
- " for name, step in sorted_reached:\n",
- " print(f\" {name:30s}: {step:8,} steps\")\n",
- "\n",
- " if not_reached:\n",
- " print(f\"\\nExperiments that did not reach plateau {target_plateau_idx}:\")\n",
- " for name in not_reached_sorted:\n",
- " print(f\" - {name}\")\n",
- "\n",
- " plt.show()\n",
- "\n",
- " return fig, ax, times, theory_levels"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "84c284b1",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Define your experiments across multiple sweeps\n",
- "sweep_experiments = {\n",
- " \"/home/facosta/group-agf/sweeps/optim_sweep_20251113_001549\": [\n",
- " \"k_4_seqone\",\n",
- " \"k_5_seqone\",\n",
- " ],\n",
- " \"/home/facosta/group-agf/sweeps/optim_sweep_20251113_145528\": [\n",
- " \"k_2_seqone\",\n",
- " \"k_3_seqone\",\n",
- " ],\n",
- "}\n",
- "\n",
- "# Load template\n",
- "template_path = os.path.join(\n",
- " \"/home/facosta/group-agf/sweeps/optim_sweep_20251113_001549\",\n",
- " \"k_4_seqone\",\n",
- " \"seed_0\",\n",
- " \"template.npy\",\n",
- ")\n",
- "template_2d = np.load(template_path)\n",
- "p1, p2 = template_2d.shape\n",
- "\n",
- "# Create color mapping by k value\n",
- "color_map, scalar_map, k_values = create_color_mapping_from_multiple_sweeps(\n",
- " sweep_experiments, \"data.k\", cmap=\"viridis\", log_scale=False\n",
- ")\n",
- "\n",
- "# Plot time to first plateau\n",
- "fig, ax, times, levels = plot_time_to_plateau_multi_sweep(\n",
- " sweep_experiments,\n",
- " template_2d,\n",
- " p1,\n",
- " p2,\n",
- " target_plateau_idx=1, # First drop\n",
- " tolerance=1.1,\n",
- " color_mapping=color_map,\n",
- " colorbar_label=\"k (sequence length)\",\n",
- " scalar_map=scalar_map,\n",
- " sort_by=\"name\", # or 'time'\n",
- " show_not_reached=True, # Show experiments that didn't reach plateau\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "78bfc515",
- "metadata": {},
- "source": [
- "## Batch size sweep (seq-to-one, k=3)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "ac01c062",
- "metadata": {},
- "outputs": [],
- "source": [
- "import yaml\n",
- "\n",
- "\n",
- "def extract_config_parameter(\n",
- " sweep_dir: str, experiment_name: str, parameter_path: str, seed: int = 0\n",
- ") -> any:\n",
- " \"\"\"\n",
- " Extract a parameter value from an experiment's config file.\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to sweep directory\n",
- " experiment_name: Name of experiment\n",
- " parameter_path: Dot-separated path to parameter (e.g., 'data.batch_size', 'training.learning_rate')\n",
- " seed: Seed number (for getting from seed_X/config.yaml)\n",
- "\n",
- " Returns:\n",
- " Parameter value or None if not found\n",
- " \"\"\"\n",
- " # Try configs directory first\n",
- " config_path = Path(sweep_dir) / \"configs\" / f\"{experiment_name}_config.yaml\"\n",
- "\n",
- " # If not there, try seed_X directory\n",
- " if not config_path.exists():\n",
- " config_path = Path(sweep_dir) / experiment_name / f\"seed_{seed}\" / \"config.yaml\"\n",
- "\n",
- " if not config_path.exists():\n",
- " return None\n",
- "\n",
- " try:\n",
- " with open(config_path, \"r\") as f:\n",
- " config = yaml.safe_load(f)\n",
- "\n",
- " # Navigate through nested structure using dot notation\n",
- " value = config\n",
- " for key in parameter_path.split(\".\"):\n",
- " value = value[key]\n",
- "\n",
- " return value\n",
- " except (KeyError, TypeError):\n",
- " return None\n",
- "\n",
- "\n",
- "def create_color_mapping(\n",
- " sweep_dir: str,\n",
- " parameter_path: str,\n",
- " experiments: Optional[List[str]] = None,\n",
- " cmap: str = \"viridis\",\n",
- " seed: int = 0,\n",
- " log_scale: bool = False,\n",
- ") -> tuple:\n",
- " \"\"\"\n",
- " Create a color mapping for experiments based on a config parameter.\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to sweep directory\n",
- " parameter_path: Dot-separated path to parameter (e.g., 'data.batch_size')\n",
- " experiments: List of experiment names (None = all)\n",
- " cmap: Colormap name\n",
- " seed: Seed number\n",
- " log_scale: Whether to use logarithmic scale for color mapping\n",
- "\n",
- " Returns:\n",
- " Tuple of (color_mapping dict, scalar_map, param_values dict)\n",
- " \"\"\"\n",
- " if experiments is None:\n",
- " experiments = get_sweep_experiments(sweep_dir)\n",
- "\n",
- " # Extract parameter values for all experiments\n",
- " param_values = {}\n",
- " for exp_name in experiments:\n",
- " value = extract_config_parameter(sweep_dir, exp_name, parameter_path, seed)\n",
- " if value is not None:\n",
- " param_values[exp_name] = value\n",
- "\n",
- " if not param_values:\n",
- " print(f\"Warning: Could not extract '{parameter_path}' from any experiments\")\n",
- " return {}, None, {}\n",
- "\n",
- " # Create color mapping\n",
- " values = list(param_values.values())\n",
- " v_min = min(values)\n",
- " v_max = max(values)\n",
- "\n",
- " # Use log or linear normalization\n",
- " if log_scale:\n",
- " if v_min <= 0:\n",
- " print(\n",
- " f\"Warning: log_scale requested but found non-positive values (min={v_min}). Using linear scale.\"\n",
- " )\n",
- " norm = plt.cm.colors.Normalize(vmin=v_min, vmax=v_max)\n",
- " else:\n",
- " norm = plt.cm.colors.LogNorm(vmin=v_min, vmax=v_max)\n",
- " else:\n",
- " norm = plt.cm.colors.Normalize(vmin=v_min, vmax=v_max)\n",
- "\n",
- " colormap = plt.cm.get_cmap(cmap)\n",
- "\n",
- " color_mapping = {}\n",
- " for exp_name, value in param_values.items():\n",
- " color_mapping[exp_name] = colormap(norm(value))\n",
- "\n",
- " # Also return the scalar mappable for colorbar\n",
- " scalar_map = plt.cm.ScalarMappable(norm=norm, cmap=colormap)\n",
- "\n",
- " return color_mapping, scalar_map, param_values"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b2da02fb",
- "metadata": {},
- "outputs": [],
- "source": [
- "def plot_loss_comparison(\n",
- " sweep_dir: str,\n",
- " experiments: Optional[List[str]] = None,\n",
- " loss_type: str = \"train\",\n",
- " log_scale: bool = True,\n",
- " figsize: tuple = (10, 6),\n",
- " seed: int = 0,\n",
- " remove_outliers: bool = False,\n",
- " outlier_window: int = 10,\n",
- " outlier_threshold: float = 3.0,\n",
- " template_2d: Optional[np.ndarray] = None,\n",
- " p1: Optional[int] = None,\n",
- " p2: Optional[int] = None,\n",
- " show_theory_bands: bool = True,\n",
- " num_theory_lines: Optional[int] = None,\n",
- " color_mapping: Optional[Dict[str, tuple]] = None,\n",
- " colorbar_label: Optional[str] = None,\n",
- " scalar_map: Optional[plt.cm.ScalarMappable] = None,\n",
- "):\n",
- " \"\"\"\n",
- " Plot and compare loss curves from multiple experiments.\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to sweep directory\n",
- " experiments: List of experiment names to plot (None = all experiments)\n",
- " loss_type: 'train' or 'val'\n",
- " log_scale: Whether to use log scale for both axes\n",
- " figsize: Figure size tuple\n",
- " seed: Seed number (default: 0)\n",
- " remove_outliers: Whether to remove outliers using local outlier replacement\n",
- " outlier_window: Window size for outlier detection (default: 10)\n",
- " outlier_threshold: Threshold in standard deviations for outlier detection (default: 3.0)\n",
- " template_2d: Optional 2D template array for computing theory lines\n",
- " p1: First dimension of template (required if template_2d is provided)\n",
- " p2: Second dimension of template (required if template_2d is provided)\n",
- " show_theory_bands: Whether to show colored bands between theory lines (default: True)\n",
- " num_theory_lines: Number of theory lines to show (default: None = show all)\n",
- " color_mapping: Dictionary mapping experiment names to RGBA colors\n",
- " colorbar_label: Label for colorbar (if color_mapping provided)\n",
- " scalar_map: ScalarMappable for colorbar (if color_mapping provided)\n",
- " \"\"\"\n",
- " if experiments is None:\n",
- " experiments = get_sweep_experiments(sweep_dir)\n",
- "\n",
- " fig, ax = plt.subplots(figsize=figsize)\n",
- "\n",
- " # Compute theory lines if template is provided\n",
- " theory_levels = None\n",
- " if template_2d is not None:\n",
- " if p1 is None or p2 is None:\n",
- " raise ValueError(\"p1 and p2 must be provided if template_2d is given\")\n",
- "\n",
- " # Import the helper function (assuming it's in utils.py)\n",
- " from gagf.rnns.utils import get_power_2d_adele\n",
- "\n",
- " # Compute power spectrum of template\n",
- " _, _, power = get_power_2d_adele(template_2d)\n",
- " power_flat = np.sort(power.flatten()[power.flatten() > 1e-20])[::-1]\n",
- "\n",
- " # Theory levels (cumulative tail sums)\n",
- " alpha_values = np.array(\n",
- " [np.sum(power_flat[k:]) for k in range(len(power_flat))]\n",
- " )\n",
- " coef = 1.0 / (p1 * p2)\n",
- " theory_levels = coef * alpha_values # strictly decreasing\n",
- "\n",
- " # Limit number of lines if specified\n",
- " if num_theory_lines is not None:\n",
- " theory_levels = theory_levels[: num_theory_lines + 1]\n",
- "\n",
- " # Generate colors for bands\n",
- " n_bands = len(theory_levels) - 1\n",
- " colors = plt.cm.tab10(np.linspace(0, 1, max(n_bands, 1)))\n",
- "\n",
- " # Draw colored bands between theory lines\n",
- " if show_theory_bands and n_bands > 0:\n",
- " for i in range(n_bands):\n",
- " y_top = theory_levels[i]\n",
- " y_bot = theory_levels[i + 1]\n",
- " ax.axhspan(\n",
- " y_bot,\n",
- " y_top,\n",
- " facecolor=colors[i % len(colors)],\n",
- " alpha=0.15,\n",
- " zorder=-3,\n",
- " )\n",
- "\n",
- " # Draw the black theory lines\n",
- " for y in theory_levels:\n",
- " ax.axhline(\n",
- " y=y,\n",
- " color=\"black\",\n",
- " linestyle=\"--\",\n",
- " linewidth=1.5,\n",
- " alpha=0.7,\n",
- " zorder=-2,\n",
- " label=\"_nolegend_\",\n",
- " )\n",
- "\n",
- " # Plot loss curves\n",
- " for exp_name in experiments:\n",
- " losses = load_experiment_losses(sweep_dir, exp_name, seed)\n",
- " if loss_type in losses:\n",
- " loss_history = losses[loss_type]\n",
- "\n",
- " # Apply outlier removal if requested\n",
- " if remove_outliers:\n",
- " loss_history, outliers_found = remove_outliers_local(\n",
- " loss_history, window=outlier_window, threshold=outlier_threshold\n",
- " )\n",
- " if outliers_found:\n",
- " print(f\"Outliers from {exp_name} removed for plot\")\n",
- "\n",
- " # Determine color\n",
- " if color_mapping and exp_name in color_mapping:\n",
- " color = color_mapping[exp_name]\n",
- " else:\n",
- " color = None # Use default color cycle\n",
- "\n",
- " ax.plot(loss_history, label=exp_name, alpha=0.8, linewidth=2, color=color)\n",
- "\n",
- " ax.set_xlabel(\"Step\", fontsize=14)\n",
- " ax.set_ylabel(f\"{loss_type.capitalize()} Loss\", fontsize=14)\n",
- " title = f\"{loss_type.capitalize()} Loss Comparison - {Path(sweep_dir).name}\"\n",
- " if remove_outliers:\n",
- " title += \" (outliers removed)\"\n",
- " ax.set_title(title, fontsize=14)\n",
- " ax.legend(fontsize=10)\n",
- " ax.grid(True, alpha=0.3)\n",
- "\n",
- " if log_scale:\n",
- " ax.set_xscale(\"log\")\n",
- " ax.set_yscale(\"log\")\n",
- "\n",
- " # Add colorbar if color mapping provided\n",
- " if color_mapping and scalar_map is not None:\n",
- " label = colorbar_label if colorbar_label else \"Parameter Value\"\n",
- " cbar = plt.colorbar(scalar_map, ax=ax, label=label, pad=0.02)\n",
- " cbar.ax.tick_params(labelsize=10)\n",
- "\n",
- " plt.tight_layout()\n",
- " plt.show()\n",
- "\n",
- " return fig, ax, theory_levels"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "8b2ce522",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Set up your sweep directory\n",
- "batch_sweep_dir = \"/home/facosta/group-agf/sweeps/batch_sweep_20251113_171834\"\n",
- "\n",
- "# Get all experiments in the sweep\n",
- "batch_experiments = get_sweep_experiments(batch_sweep_dir)\n",
- "print(f\"Found experiments: {batch_experiments}\")\n",
- "\n",
- "# Load losses for all experiments\n",
- "batch_all_losses = load_all_sweep_losses(batch_sweep_dir)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "520500f3",
- "metadata": {},
- "outputs": [],
- "source": [
- "# First, load the template from one of your experiments\n",
- "# (assuming they all use the same template)\n",
- "template_path = os.path.join(batch_sweep_dir, \"batch_1000\", \"seed_0\", \"template.npy\")\n",
- "template_2d = np.load(template_path)\n",
- "p1, p2 = template_2d.shape\n",
- "\n",
- "plt.imshow(template_2d)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b9761849",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Example 1: Color by batch_size\n",
- "color_map, scalar_map, batch_values = create_color_mapping(\n",
- " batch_sweep_dir,\n",
- " \"data.batch_size\",\n",
- " cmap=\"viridis\",\n",
- " log_scale=True,\n",
- ")\n",
- "\n",
- "plot_loss_comparison(\n",
- " batch_sweep_dir,\n",
- " template_2d=template_2d,\n",
- " p1=p1,\n",
- " p2=p2,\n",
- " color_mapping=color_map,\n",
- " colorbar_label=\"Batch Size\",\n",
- " scalar_map=scalar_map,\n",
- " num_theory_lines=10,\n",
- " remove_outliers=True,\n",
- " log_scale=False,\n",
- ")\n",
- "\n",
- "\n",
- "# Print the extracted values to verify\n",
- "print(\"Batch sizes:\", batch_values)\n",
- "\n",
- "# Print the theory levels (plateau values)\n",
- "print(\"Theory plateau levels:\")\n",
- "for i, level in enumerate(theory_levels):\n",
- " print(f\" Plateau {i}: {level:.6e}\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "7acdd373",
- "metadata": {},
- "source": [
- "## MLP Scaling Sweep Heatmap\n",
- "\n",
- "Visualize the relationship between sequence length (k) and hidden dimension (width) for SequentialMLP.\n",
- "\n",
- "**Sweep parameters:**\n",
- "- Model: SequentialMLP\n",
- "- Dimension: 1, p = 10\n",
- "- k values: 2, 3, 4, 5, 6 (5 values)\n",
- "- hidden_dim values: 60, 360, 2160, 12960, 77760 (5 values = 10×6¹ through 10×6⁵)\n",
- "- num_steps varies with k: k=2→50k, k=3→100k, k=4→150k, k=5→200k, k=6→250k\n",
- "- Total: 25 experiments × 3 seeds = 75 runs"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "92d3289e",
- "metadata": {},
- "outputs": [],
- "source": [
- "def plot_scaling_heatmap(\n",
- " sweep_dir: str,\n",
- " k_values: list = [2, 3, 4, 5, 6, 7, 8],\n",
- " hidden_dims: list = [60, 360, 2160, 12960, 77760, 466560, 2799360],\n",
- " use_log_scale: bool = True,\n",
- " save_path: str = None\n",
- "): \n",
- " \"\"\"\n",
- " Create a heatmap of final train loss vs k and hidden_dim.\n",
- " \n",
- " Args:\n",
- " sweep_dir: Path to the sweep directory\n",
- " k_values: List of k values (x-axis)\n",
- " hidden_dims: List of hidden dimension values (y-axis)\n",
- " use_log_scale: Whether to use log scale for the loss values\n",
- " save_path: Optional path to save the figure\n",
- " \"\"\"\n",
- " # Load results\n",
- " grid, std_grid = load_sweep_results_grid(sweep_dir, k_values, hidden_dims)\n",
- " \n",
- " # Apply log scale if requested\n",
- " plot_grid = np.log10(grid) if use_log_scale else grid\n",
- " \n",
- " # Create figure\n",
- " fig, ax = plt.subplots(figsize=(12, 10)) \n",
- " \n",
- " \n",
- " # Create heatmap\n",
- " im = ax.imshow(plot_grid, aspect='auto', cmap='viridis', origin='lower')\n",
- " \n",
- " # Set ticks and labels\n",
- " ax.set_xticks(range(len(k_values)))\n",
- " ax.set_yticks(range(len(hidden_dims)))\n",
- " ax.set_xticklabels(k_values)\n",
- " ax.set_yticklabels([f\"{h:,}\" for h in hidden_dims])\n",
- " \n",
- " # Labels\n",
- " ax.set_xlabel('Sequence Length (k)', fontsize=14)\n",
- " ax.set_ylabel('Hidden Dimension (width)', fontsize=14)\n",
- " title = 'Final Train Loss: SequentialMLP Scaling'\n",
- " if use_log_scale:\n",
- " title += ' (log₁₀)'\n",
- " ax.set_title(title, fontsize=16, pad=20)\n",
- " \n",
- " # Add colorbar\n",
- " cbar = plt.colorbar(im, ax=ax)\n",
- " cbar_label = 'log₁₀(Train Loss)' if use_log_scale else 'Train Loss'\n",
- " cbar.set_label(cbar_label, fontsize=12)\n",
- " \n",
- " # Add text annotations\n",
- "\n",
- " # Add text annotations\n",
- " for i in range(len(hidden_dims)):\n",
- " for j in range(len(k_values)):\n",
- " if not np.isnan(grid[i, j]):\n",
- " text_val = f\"{grid[i, j]:.2e}\"\n",
- " text_color = 'white' if plot_grid[i, j] < plot_grid[~np.isnan(plot_grid)].mean() else 'black'\n",
- " ax.text(j, i, text_val, ha='center', va='center', \n",
- " color=text_color, fontsize=8)\n",
- " \n",
- " plt.tight_layout()\n",
- " \n",
- " if save_path:\n",
- " plt.savefig(save_path, dpi=300, bbox_inches='tight')\n",
- " print(f\"Figure saved to: {save_path}\")\n",
- " \n",
- " plt.show()\n",
- " \n",
- " return grid"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "2e5c22df",
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b10c9d83",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "group-agf",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.2"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/notebooks/seq_mlp.ipynb b/notebooks/seq_mlp.ipynb
deleted file mode 100644
index 5dd9168..0000000
--- a/notebooks/seq_mlp.ipynb
+++ /dev/null
@@ -1,2071 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "4265f1a8",
- "metadata": {},
- "source": [
- "# MLP Scaling: $H$ vs $k$\n",
- "\n",
- "Hidden neurons vs sequence length scaling experiments."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5a05ce99",
- "metadata": {},
- "source": [
- "## Set up"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5cfe1142",
- "metadata": {},
- "outputs": [],
- "source": [
- "# autoreload\n",
- "%load_ext autoreload\n",
- "%autoreload 2\n",
- "# jupyter black formatter\n",
- "%load_ext jupyter_black\n",
- "\n",
- "import subprocess\n",
- "import os\n",
- "import sys\n",
- "\n",
- "gitroot_path = subprocess.check_output(\n",
- " [\"git\", \"rev-parse\", \"--show-toplevel\"], universal_newlines=True\n",
- ").strip()\n",
- "\n",
- "os.chdir(gitroot_path)\n",
- "print(\"Working directory: \", os.getcwd())\n",
- "\n",
- "if gitroot_path not in sys.path:\n",
- " sys.path.insert(0, gitroot_path)\n",
- "print(\"Directory added to path: \", gitroot_path)\n",
- "\n",
- "import yaml\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "from pathlib import Path"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "15a42140",
- "metadata": {},
- "source": [
- "## Specify experiment directory"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "2ebd6750",
- "metadata": {},
- "outputs": [],
- "source": [
- "sweep_dir = \"/data/facosta/sweeps/sweep_mlp_scaling_20251212_161329\"\n",
- "os.path.exists(sweep_dir)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "cb3acce2",
- "metadata": {},
- "source": [
- "### Final Loss Heatmap"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "af291059",
- "metadata": {},
- "outputs": [],
- "source": [
- "def load_sweep_results_grid(sweep_dir: str, k_values: list, hidden_dims: list):\n",
- " \"\"\"\n",
- " Load sweep results and organize into a grid for heatmap visualization.\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to the sweep directory\n",
- " k_values: List of k (sequence length) values\n",
- " hidden_dims: List of hidden dimension values\n",
- "\n",
- " Returns:\n",
- " grid: 2D numpy array with shape (len(hidden_dims), len(k_values))\n",
- " containing mean final train losses\n",
- " std_grid: 2D numpy array with standard deviations (if multiple seeds)\n",
- " \"\"\"\n",
- " sweep_path = Path(sweep_dir)\n",
- "\n",
- " # Initialize grids\n",
- " grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n",
- " std_grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n",
- "\n",
- " # Load results for each experiment\n",
- " for i, h in enumerate(hidden_dims):\n",
- " for j, k in enumerate(k_values):\n",
- " exp_name = f\"k{k}_h{h}\"\n",
- " exp_dir = sweep_path / exp_name\n",
- "\n",
- " if not exp_dir.exists():\n",
- " print(f\"Warning: Experiment {exp_name} not found\")\n",
- " continue\n",
- "\n",
- " # Load experiment summary\n",
- " summary_file = exp_dir / \"experiment_summary.yaml\"\n",
- " if summary_file.exists():\n",
- " with open(summary_file, \"r\") as f:\n",
- " summary = yaml.safe_load(f)\n",
- "\n",
- " # Get mean train loss\n",
- " if \"train_loss_stats\" in summary:\n",
- " grid[i, j] = summary[\"train_loss_stats\"][\"mean\"]\n",
- " std_grid[i, j] = summary[\"train_loss_stats\"][\"std\"]\n",
- " else:\n",
- " print(f\"Warning: No train_loss_stats in {exp_name}\")\n",
- " else:\n",
- " print(f\"Warning: No summary file for {exp_name}\")\n",
- "\n",
- " return grid, std_grid"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "54e321b2",
- "metadata": {},
- "outputs": [],
- "source": [
- "k_values = [2, 3, 4, 5, 6, 7, 8]\n",
- "\n",
- "# hidden_dims = [60, 360, 2160, 12960, 77760]\n",
- "hidden_dims = [6, 6**2, 6**3, 6**4, 6**5, 6**6]\n",
- "\n",
- "grid, _ = load_sweep_results_grid(sweep_dir, k_values, hidden_dims)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "2f62021a",
- "metadata": {},
- "outputs": [],
- "source": [
- "plt.figure(figsize=(8, 6))\n",
- "plt.imshow(grid, aspect=\"auto\", norm=None)\n",
- "\n",
- "# Correct labels: rows = hidden_dims (y), columns = k_values (x)\n",
- "plt.xlabel(\"Sequence Length (k)\")\n",
- "plt.ylabel(\"Hidden Dimension\")\n",
- "\n",
- "# Set tick labels to show actual values\n",
- "plt.xticks(range(len(k_values)), k_values)\n",
- "plt.yticks(range(len(hidden_dims)), hidden_dims)\n",
- "\n",
- "plt.gca().invert_yaxis()\n",
- "\n",
- "plt.title(\"Final Train Loss\")\n",
- "\n",
- "plt.colorbar(label=\"Final Train Loss\")\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "fbcf4b02",
- "metadata": {},
- "source": [
- "### Loss Curve Integral Heatmap"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4bb980bb",
- "metadata": {},
- "outputs": [],
- "source": [
- "def load_sweep_results_grid_integral(sweep_dir: str, k_values: list, hidden_dims: list):\n",
- " \"\"\"\n",
- " Load sweep results and compute integral of loss curves.\n",
- "\n",
- " Returns:\n",
- " grid: 2D array with mean integral of loss curves\n",
- " std_grid: 2D array with standard deviations across seeds\n",
- " \"\"\"\n",
- " sweep_path = Path(sweep_dir)\n",
- "\n",
- " grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n",
- " std_grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n",
- "\n",
- " for i, h in enumerate(hidden_dims):\n",
- " for j, k in enumerate(k_values):\n",
- " exp_name = f\"k{k}_h{h}\"\n",
- " exp_dir = sweep_path / exp_name\n",
- "\n",
- " if not exp_dir.exists():\n",
- " continue\n",
- "\n",
- " # Collect integrals from all seeds\n",
- " integrals = []\n",
- " for seed_dir in exp_dir.glob(\"seed_*\"):\n",
- " loss_file = seed_dir / \"train_loss_history.npy\"\n",
- " if loss_file.exists():\n",
- " loss_history = np.load(loss_file)\n",
- " # Compute integral using trapezoidal rule\n",
- " integral = np.trapz(loss_history)\n",
- " integrals.append(integral)\n",
- "\n",
- " if integrals:\n",
- " grid[i, j] = np.mean(integrals)\n",
- " std_grid[i, j] = np.std(integrals) if len(integrals) > 1 else 0.0\n",
- "\n",
- " return grid, std_grid"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "8e5017cd",
- "metadata": {},
- "outputs": [],
- "source": [
- "integral_grid, integral_std = load_sweep_results_grid_integral(\n",
- " sweep_dir, k_values, hidden_dims\n",
- ")\n",
- "\n",
- "from matplotlib.colors import LogNorm, SymLogNorm\n",
- "\n",
- "plt.figure(figsize=(8, 6))\n",
- "plt.imshow(integral_grid, aspect=\"auto\", norm=LogNorm())\n",
- "plt.xlabel(\"Sequence Length (k)\")\n",
- "plt.ylabel(\"Hidden Dimension\")\n",
- "plt.xticks(range(len(k_values)), k_values)\n",
- "plt.yticks(range(len(hidden_dims)), hidden_dims)\n",
- "\n",
- "plt.gca().invert_yaxis()\n",
- "plt.colorbar(label=\"Loss Curve Integral\")\n",
- "plt.title(\"Loss Curve Integral\")\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "6b9956cf",
- "metadata": {},
- "source": [
- "### Steps to Convergence Heatmap"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4bf3dbcc",
- "metadata": {},
- "outputs": [],
- "source": [
- "def load_sweep_results_grid_convergence(\n",
- " sweep_dir: str, k_values: list, hidden_dims: list, reduction_threshold: float = 0.99\n",
- "):\n",
- " \"\"\"\n",
- " Load sweep results and compute steps to convergence.\n",
- "\n",
- " Convergence is defined as reaching `reduction_threshold` loss reduction\n",
- " (e.g., 0.99 = 99% reduction from initial loss).\n",
- "\n",
- " If convergence is not reached, the grid point is set to NaN (blacked out).\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to the sweep directory\n",
- " k_values: List of k (sequence length) values\n",
- " hidden_dims: List of hidden dimension values\n",
- " reduction_threshold: Fraction of loss reduction to consider converged\n",
- "\n",
- " Returns:\n",
- " grid: 2D array with mean steps to convergence (NaN if didn't converge)\n",
- " std_grid: 2D array with standard deviations across seeds\n",
- " \"\"\"\n",
- " sweep_path = Path(sweep_dir)\n",
- "\n",
- " grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n",
- " std_grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n",
- "\n",
- " for i, h in enumerate(hidden_dims):\n",
- " for j, k in enumerate(k_values):\n",
- " exp_name = f\"k{k}_h{h}\"\n",
- " exp_dir = sweep_path / exp_name\n",
- "\n",
- " if not exp_dir.exists():\n",
- " continue\n",
- "\n",
- " # Collect convergence steps from all seeds\n",
- " convergence_steps = []\n",
- " for seed_dir in exp_dir.glob(\"seed_*\"):\n",
- " loss_file = seed_dir / \"train_loss_history.npy\"\n",
- " if loss_file.exists():\n",
- " loss_history = np.load(loss_file)\n",
- " initial_loss = loss_history[0]\n",
- "\n",
- " if initial_loss > 0:\n",
- " # Compute reduction at each step\n",
- " reductions = 1 - loss_history / initial_loss\n",
- "\n",
- " # Find first step where reduction >= threshold\n",
- " converged_mask = reductions >= reduction_threshold\n",
- " if np.any(converged_mask):\n",
- " step = np.argmax(converged_mask) # First True\n",
- " convergence_steps.append(step)\n",
- " # else: Never converged - don't add to list\n",
- "\n",
- " if convergence_steps:\n",
- " grid[i, j] = np.mean(convergence_steps)\n",
- " std_grid[i, j] = (\n",
- " np.std(convergence_steps) if len(convergence_steps) > 1 else 0.0\n",
- " )\n",
- " # else: No seeds converged - grid[i,j] remains NaN (blacked out)\n",
- "\n",
- " return grid, std_grid"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6b083a66",
- "metadata": {},
- "outputs": [],
- "source": [
- "reduction_threshold = 0.6\n",
- "conv_grid, conv_std = load_sweep_results_grid_convergence(\n",
- " sweep_dir, k_values, hidden_dims, reduction_threshold=reduction_threshold\n",
- ")\n",
- "plt.figure(figsize=(10, 6)) # Made slightly wider to accommodate legend\n",
- "cmap = plt.cm.viridis_r.copy()\n",
- "cmap.set_bad(color=\"black\")\n",
- "plt.imshow(conv_grid, aspect=\"equal\", cmap=cmap, norm=LogNorm())\n",
- "\n",
- "plt.xlabel(\"Sequence Length ($k$)\")\n",
- "plt.ylabel(\"Hidden Dimension $H$\")\n",
- "plt.xticks(range(len(k_values)), k_values)\n",
- "\n",
- "# Create y-tick labels with both power notation and actual values\n",
- "ytick_labels = [f\"$6^{i+1}$ ({val:,})\" for i, val in enumerate(hidden_dims)]\n",
- "plt.yticks(range(len(hidden_dims)), ytick_labels)\n",
- "plt.gca().invert_yaxis()\n",
- "\n",
- "x_step = np.arange(len(k_values)) - 0.5\n",
- "y_step = np.minimum(x_step, len(hidden_dims)) # Example: stays within bounds\n",
- "\n",
- "plt.step(\n",
- " x_step,\n",
- " y_step,\n",
- " where=\"post\",\n",
- " color=\"red\",\n",
- " linewidth=3,\n",
- " linestyle=\"--\",\n",
- " label=\"Theory boundary ($H > 6^{k-1}$)\",\n",
- ")\n",
- "\n",
- "# Place legend outside the plot area (to the right)\n",
- "plt.legend(loc=\"upper center\", bbox_to_anchor=(0.5, -0.12), fontsize=12, frameon=True)\n",
- "\n",
- "plt.colorbar(label=f\"Steps to {reduction_threshold*100}% Convergence\")\n",
- "plt.title(f\"Steps to {reduction_threshold*100}% Convergence (black = did not converge)\")\n",
- "plt.tight_layout() # Adjust layout to prevent clipping\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "130132a9",
- "metadata": {},
- "outputs": [],
- "source": [
- "def load_sweep_results_grid_spikiness(\n",
- " sweep_dir: str,\n",
- " k_values: list,\n",
- " hidden_dims: list,\n",
- "):\n",
- " \"\"\"\n",
- " Compute fraction of training steps where loss increased (instability).\n",
- "\n",
- " Returns:\n",
- " grid: 2D array with mean frac_upward across seeds\n",
- " std_grid: 2D array with standard deviations\n",
- " \"\"\"\n",
- " sweep_path = Path(sweep_dir)\n",
- "\n",
- " grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n",
- " std_grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n",
- "\n",
- " for i, h in enumerate(hidden_dims):\n",
- " for j, k in enumerate(k_values):\n",
- " exp_name = f\"k{k}_h{h}\"\n",
- " exp_dir = sweep_path / exp_name\n",
- "\n",
- " if not exp_dir.exists():\n",
- " continue\n",
- "\n",
- " frac_upwards = []\n",
- " for seed_dir in exp_dir.glob(\"seed_*\"):\n",
- " loss_file = seed_dir / \"train_loss_history.npy\"\n",
- " if loss_file.exists():\n",
- " loss_history = np.load(loss_file)\n",
- " log_loss = np.log10(loss_history + 1e-10)\n",
- " log_changes = np.diff(log_loss)\n",
- "\n",
- " # Fraction of steps where loss went UP\n",
- " frac_upward = np.sum(log_changes > 0) / len(log_changes)\n",
- " frac_upwards.append(frac_upward)\n",
- "\n",
- " if frac_upwards:\n",
- " grid[i, j] = np.mean(frac_upwards)\n",
- " std_grid[i, j] = np.std(frac_upwards) if len(frac_upwards) > 1 else 0.0\n",
- "\n",
- " return grid, std_grid"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4fba73d0",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Load both convergence and spikiness data\n",
- "reduction_threshold = 0.6 # Adjust as needed\n",
- "conv_grid, conv_std = load_sweep_results_grid_convergence(\n",
- " sweep_dir, k_values, hidden_dims, reduction_threshold=reduction_threshold\n",
- ")\n",
- "\n",
- "stability_grid, stability_std = load_sweep_results_grid_spikiness(\n",
- " sweep_dir, k_values, hidden_dims\n",
- ")\n",
- "\n",
- "# Create a masked version of stability_grid where non-converged runs are NaN\n",
- "stability_grid_masked = stability_grid.copy()\n",
- "stability_grid_masked[np.isnan(conv_grid)] = np.nan # Mask non-converged runs\n",
- "\n",
- "# Create plot\n",
- "plt.figure(figsize=(10, 6))\n",
- "\n",
- "# Use a colormap with bad values (NaN) shown as black\n",
- "cmap = plt.cm.plasma.copy()\n",
- "cmap.set_bad(color=\"black\")\n",
- "\n",
- "plt.imshow(stability_grid_masked, aspect=\"equal\", cmap=cmap, vmin=0, vmax=0.5)\n",
- "\n",
- "plt.xlabel(\"Sequence Length ($k$)\")\n",
- "plt.ylabel(\"Hidden Dimension $H$\")\n",
- "plt.xticks(range(len(k_values)), k_values)\n",
- "\n",
- "# Create y-tick labels with both power notation and actual values\n",
- "ytick_labels = [f\"$6^{i+1}$ ({val:,})\" for i, val in enumerate(hidden_dims)]\n",
- "plt.yticks(range(len(hidden_dims)), ytick_labels)\n",
- "plt.gca().invert_yaxis()\n",
- "\n",
- "x_step = np.arange(len(k_values)) - 0.5\n",
- "y_step = np.minimum(x_step, len(hidden_dims)) # Example: stays within bounds\n",
- "\n",
- "plt.step(\n",
- " x_step,\n",
- " y_step,\n",
- " where=\"post\",\n",
- " color=\"red\",\n",
- " linewidth=3,\n",
- " linestyle=\"--\",\n",
- " label=\"Theory boundary ($H > 6^{k-1}$)\",\n",
- ")\n",
- "\n",
- "plt.legend(loc=\"upper left\", fontsize=10, frameon=True)\n",
- "\n",
- "plt.colorbar(label=\"Fraction of Upward Steps\")\n",
- "plt.title(\n",
- " f\"Training Instability\\n(black = did not converge)\",\n",
- " fontsize=13,\n",
- " fontweight=\"bold\",\n",
- ")\n",
- "\n",
- "plt.tight_layout()\n",
- "plt.show()\n",
- "\n",
- "# Print summary\n",
- "n_converged = np.sum(~np.isnan(conv_grid))\n",
- "n_not_converged = np.sum(np.isnan(conv_grid))\n",
- "print(f\"\\n{'='*60}\")\n",
- "print(f\"Converged runs: {n_converged} ({100*n_converged/conv_grid.size:.1f}%)\")\n",
- "print(\n",
- " f\"Did not converge (black): {n_not_converged} ({100*n_not_converged/conv_grid.size:.1f}%)\"\n",
- ")\n",
- "print(f\"{'='*60}\\n\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "7a228687",
- "metadata": {},
- "outputs": [],
- "source": [
- "stability_grid, stability_std = load_sweep_results_grid_spikiness(\n",
- " sweep_dir, k_values, hidden_dims\n",
- ")\n",
- "\n",
- "# Create side-by-side subplots\n",
- "plt.figure(figsize=(10, 6))\n",
- "\n",
- "# ========== RIGHT PANEL: SPIKINESS ==========\n",
- "# Use a different colormap - plasma, magma, or RdYlGn_r work well\n",
- "plt.imshow(stability_grid, aspect=\"equal\", cmap=\"plasma\", vmin=0, vmax=0.5)\n",
- "\n",
- "\n",
- "plt.xlabel(\"Sequence Length ($k$)\")\n",
- "plt.ylabel(\"Hidden Dimension $H$\")\n",
- "plt.xticks(range(len(k_values)), k_values)\n",
- "\n",
- "# Create y-tick labels with both power notation and actual values\n",
- "ytick_labels = [f\"$6^{i+1}$ ({val:,})\" for i, val in enumerate(hidden_dims)]\n",
- "plt.yticks(range(len(hidden_dims)), ytick_labels)\n",
- "plt.gca().invert_yaxis()\n",
- "\n",
- "x_step = np.arange(len(k_values)) - 0.5\n",
- "y_step = np.minimum(x_step, len(hidden_dims)) # Example: stays within bounds\n",
- "\n",
- "plt.step(\n",
- " x_step,\n",
- " y_step,\n",
- " where=\"post\",\n",
- " color=\"red\",\n",
- " linewidth=3,\n",
- " linestyle=\"--\",\n",
- " label=\"Theory boundary ($H > 6^{k-1}$)\",\n",
- ")\n",
- "\n",
- "plt.legend(loc=\"upper left\", fontsize=10, frameon=True)\n",
- "\n",
- "plt.colorbar(label=\"Fraction of Upward Steps\")\n",
- "# cbar2.ax.axhline(0.3, color=\"white\", linewidth=2, linestyle=\"--\") # Mark threshold\n",
- "plt.title(\n",
- " \"Training Instability\\n(higher = more unstable)\", fontsize=13, fontweight=\"bold\"\n",
- ")\n",
- "\n",
- "plt.tight_layout()\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5eb6bb45",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Load both metrics\n",
- "reduction_threshold = 0.6\n",
- "conv_grid, conv_std = load_sweep_results_grid_convergence(\n",
- " sweep_dir, k_values, hidden_dims, reduction_threshold=reduction_threshold\n",
- ")\n",
- "stability_grid, stability_std = load_sweep_results_grid_spikiness(\n",
- " sweep_dir, k_values, hidden_dims\n",
- ")\n",
- "\n",
- "# Create binary mask: 1 if converged AND spiky, 0 otherwise\n",
- "spikiness_threshold = 0.3 # Adjust this threshold as needed\n",
- "\n",
- "converged_and_spiky = np.zeros_like(conv_grid)\n",
- "for i in range(len(hidden_dims)):\n",
- " for j in range(len(k_values)):\n",
- " converged = not np.isnan(conv_grid[i, j])\n",
- " spiky = stability_grid[i, j] > spikiness_threshold\n",
- "\n",
- " if converged and spiky:\n",
- " converged_and_spiky[i, j] = 1.0\n",
- " else:\n",
- " converged_and_spiky[i, j] = np.nan # Will show as white\n",
- "\n",
- "# Plot\n",
- "plt.figure(figsize=(8, 6.5))\n",
- "\n",
- "# Custom colormap: white for NaN, red for 1\n",
- "cmap = plt.cm.Reds.copy()\n",
- "cmap.set_bad(color=\"white\")\n",
- "\n",
- "im = plt.imshow(converged_and_spiky, aspect=\"equal\", cmap=cmap, vmin=0, vmax=1)\n",
- "\n",
- "plt.xlabel(\"Sequence Length (k)\", fontsize=12)\n",
- "plt.ylabel(\"Hidden Dimension\", fontsize=12)\n",
- "plt.xticks(range(len(k_values)), k_values)\n",
- "\n",
- "ytick_labels = [f\"$6^{i+1}$ ({val:,})\" for i, val in enumerate(hidden_dims)]\n",
- "plt.yticks(range(len(hidden_dims)), ytick_labels)\n",
- "plt.gca().invert_yaxis()\n",
- "\n",
- "# Add theory boundary\n",
- "x_step = np.arange(len(k_values)) - 0.5\n",
- "y_step = np.minimum(x_step, len(hidden_dims))\n",
- "plt.step(\n",
- " x_step,\n",
- " y_step,\n",
- " where=\"post\",\n",
- " color=\"blue\",\n",
- " linewidth=3,\n",
- " linestyle=\"--\",\n",
- " label=\"Theory boundary\",\n",
- ")\n",
- "plt.legend(loc=\"upper left\", fontsize=10, frameon=True)\n",
- "\n",
- "# Add text annotations showing spikiness values in colored cells\n",
- "for i in range(len(hidden_dims)):\n",
- " for j in range(len(k_values)):\n",
- " if converged_and_spiky[i, j] == 1.0:\n",
- " spikiness_val = stability_grid[i, j]\n",
- " plt.text(\n",
- " j,\n",
- " i,\n",
- " f\"{spikiness_val:.2f}\",\n",
- " ha=\"center\",\n",
- " va=\"center\",\n",
- " fontsize=9,\n",
- " color=\"white\",\n",
- " fontweight=\"bold\",\n",
- " )\n",
- "\n",
- "plt.colorbar(im, label=\"Converged & Spiky\", ticks=[0, 1])\n",
- "plt.title(\n",
- " f\"Converged BUT Spiky Runs\\n(threshold: frac_upward > {spikiness_threshold})\",\n",
- " fontsize=13,\n",
- " fontweight=\"bold\",\n",
- ")\n",
- "plt.tight_layout()\n",
- "plt.show()\n",
- "\n",
- "# Print summary\n",
- "n_converged_spiky = np.sum(converged_and_spiky == 1.0)\n",
- "n_converged_total = np.sum(~np.isnan(conv_grid))\n",
- "print(f\"\\n{'='*60}\")\n",
- "print(f\"SUMMARY: Converged & Spiky Runs\")\n",
- "print(f\"{'='*60}\")\n",
- "print(f\"Spikiness threshold: {spikiness_threshold} (frac_upward)\")\n",
- "print(f\"Converged & spiky: {n_converged_spiky}\")\n",
- "print(f\"Total converged: {n_converged_total}\")\n",
- "print(f\"Percentage: {100*n_converged_spiky/n_converged_total:.1f}%\")\n",
- "print(f\"{'='*60}\\n\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c7b2bf5c",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Load both metrics\n",
- "reduction_threshold = 0.6\n",
- "conv_grid, conv_std = load_sweep_results_grid_convergence(\n",
- " sweep_dir, k_values, hidden_dims, reduction_threshold=reduction_threshold\n",
- ")\n",
- "stability_grid, stability_std = load_sweep_results_grid_spikiness(\n",
- " sweep_dir, k_values, hidden_dims\n",
- ")\n",
- "\n",
- "# Parameters\n",
- "spikiness_threshold = 0.3\n",
- "\n",
- "# Create modified grid for plotting\n",
- "# Strategy: Use a modified colormap and data array\n",
- "plot_grid = conv_grid.copy()\n",
- "\n",
- "# Create mask for spiky converged runs\n",
- "spiky_mask = np.zeros_like(conv_grid, dtype=bool)\n",
- "for i in range(len(hidden_dims)):\n",
- " for j in range(len(k_values)):\n",
- " converged = not np.isnan(conv_grid[i, j])\n",
- " spiky = stability_grid[i, j] > spikiness_threshold\n",
- " if converged and spiky:\n",
- " spiky_mask[i, j] = True\n",
- "\n",
- "# Plot\n",
- "fig, ax = plt.subplots(figsize=(10, 6.5))\n",
- "\n",
- "# First: plot convergence grid with viridis_r (will handle black for NaN)\n",
- "cmap_conv = plt.cm.viridis_r.copy()\n",
- "cmap_conv.set_bad(color=\"black\")\n",
- "im = ax.imshow(plot_grid, aspect=\"equal\", cmap=cmap_conv, norm=LogNorm())\n",
- "\n",
- "# Second: overlay red patches for spiky converged runs\n",
- "for i in range(len(hidden_dims)):\n",
- " for j in range(len(k_values)):\n",
- " if spiky_mask[i, j]:\n",
- " # Draw red square\n",
- " rect = plt.Rectangle(\n",
- " (j - 0.5, i - 0.5),\n",
- " 1,\n",
- " 1,\n",
- " facecolor=\"red\",\n",
- " edgecolor=\"darkred\",\n",
- " linewidth=2,\n",
- " alpha=0.9,\n",
- " )\n",
- " ax.add_patch(rect)\n",
- "\n",
- " # # Add convergence value in white text\n",
- " # conv_val = conv_grid[i, j]\n",
- " # ax.text(\n",
- " # j,\n",
- " # i,\n",
- " # f\"{int(conv_val)}\",\n",
- " # ha=\"center\",\n",
- " # va=\"center\",\n",
- " # fontsize=8,\n",
- " # color=\"white\",\n",
- " # fontweight=\"bold\",\n",
- " # )\n",
- "\n",
- "# Formatting\n",
- "ax.set_xlabel(\"Sequence Length (k)\", fontsize=12)\n",
- "ax.set_ylabel(\"Hidden Dimension\", fontsize=12)\n",
- "ax.set_xticks(range(len(k_values)))\n",
- "ax.set_xticklabels(k_values)\n",
- "\n",
- "ytick_labels = [f\"$6^{i+1}$ ({val:,})\" for i, val in enumerate(hidden_dims)]\n",
- "ax.set_yticks(range(len(hidden_dims)))\n",
- "ax.set_yticklabels(ytick_labels)\n",
- "ax.invert_yaxis()\n",
- "\n",
- "# Add theory boundary\n",
- "x_step = np.arange(len(k_values)) - 0.5\n",
- "y_step = np.minimum(x_step, len(hidden_dims))\n",
- "ax.step(\n",
- " x_step,\n",
- " y_step,\n",
- " where=\"post\",\n",
- " color=\"cyan\",\n",
- " linewidth=3,\n",
- " linestyle=\"--\",\n",
- " label=\"Theory boundary\",\n",
- ")\n",
- "\n",
- "# Custom legend\n",
- "from matplotlib.patches import Patch\n",
- "\n",
- "legend_elements = [\n",
- " Patch(facecolor=\"black\", label=\"Did not converge\"),\n",
- " Patch(\n",
- " facecolor=\"red\",\n",
- " edgecolor=\"darkred\",\n",
- " linewidth=2,\n",
- " label=f\"Spiky (frac_up > {spikiness_threshold})\",\n",
- " ),\n",
- " Patch(facecolor=\"yellow\", label=\"Smooth convergence\"),\n",
- " plt.Line2D(\n",
- " [0], [0], color=\"cyan\", linewidth=3, linestyle=\"--\", label=\"Theory boundary\"\n",
- " ),\n",
- "]\n",
- "ax.legend(\n",
- " handles=legend_elements,\n",
- " loc=\"upper center\",\n",
- " bbox_to_anchor=(0.5, -0.12),\n",
- " fontsize=10,\n",
- " frameon=True,\n",
- " ncol=4,\n",
- ")\n",
- "\n",
- "plt.colorbar(im, ax=ax, label=f\"Steps to {reduction_threshold*100}% Convergence\")\n",
- "ax.set_title(\"Convergence Speed & Spikiness Combined\", fontsize=13, fontweight=\"bold\")\n",
- "plt.tight_layout()\n",
- "plt.show()\n",
- "\n",
- "# Print summary\n",
- "n_not_converged = np.sum(np.isnan(conv_grid))\n",
- "n_converged_spiky = np.sum(spiky_mask)\n",
- "n_converged_smooth = np.sum(~np.isnan(conv_grid)) - n_converged_spiky\n",
- "total = conv_grid.size\n",
- "\n",
- "print(f\"\\n{'='*60}\")\n",
- "print(f\"SUMMARY\")\n",
- "print(f\"{'='*60}\")\n",
- "print(\n",
- " f\"Did not converge (black): {n_not_converged:3d} ({100*n_not_converged/total:.1f}%)\"\n",
- ")\n",
- "print(\n",
- " f\"Spiky converged (red): {n_converged_spiky:3d} ({100*n_converged_spiky/total:.1f}%)\"\n",
- ")\n",
- "print(\n",
- " f\"Smooth converged (color): {n_converged_smooth:3d} ({100*n_converged_smooth/total:.1f}%)\"\n",
- ")\n",
- "print(f\"{'='*60}\\n\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "2f5cbbf5",
- "metadata": {},
- "source": [
- "### Curve plot: Convergence steps vs Sequence Length $k$ for different hidden dimensions\n",
- "- x-axis: sequence length $k$\n",
- "- y-axis: number of steps to convergence\n",
- "- different curves for different hidden dimensions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "adddc2c6",
- "metadata": {},
- "outputs": [],
- "source": [
- "def plot_convergence_vs_k(\n",
- " conv_grid,\n",
- " k_values,\n",
- " hidden_dims,\n",
- " save_path=None,\n",
- " show=True,\n",
- " log_x=True,\n",
- " log_y=True,\n",
- " reduction_threshold=0.9,\n",
- "):\n",
- " \"\"\"\n",
- " Plot steps to convergence vs sequence length k for different hidden dimensions.\n",
- "\n",
- " Args:\n",
- " conv_grid: 2D array (len(hidden_dims), len(k_values)) with convergence steps\n",
- " k_values: List of k (sequence length) values\n",
- " hidden_dims: List of hidden dimension values\n",
- " save_path: Where to save the plot\n",
- " show: Whether to display the plot\n",
- " log_x: Whether to use log scale for x-axis\n",
- " log_y: Whether to use log scale for y-axis\n",
- " reduction_threshold: Threshold used for convergence\n",
- " \"\"\"\n",
- " fig, ax = plt.subplots(figsize=(10, 6))\n",
- "\n",
- " # Use a nice sequential colormap for different widths\n",
- " colors = plt.cm.plasma(np.linspace(0.15, 0.95, len(hidden_dims)))\n",
- "\n",
- " for i, (h, color) in enumerate(zip(hidden_dims, colors)):\n",
- " # Extract convergence steps for this hidden dim across all k values\n",
- " steps_for_h = conv_grid[i, :]\n",
- "\n",
- " # Only plot converged points\n",
- " converged_mask = ~np.isnan(steps_for_h)\n",
- " k_converged = np.array(k_values)[converged_mask]\n",
- " steps_converged = steps_for_h[converged_mask]\n",
- "\n",
- " if len(steps_converged) > 0:\n",
- " # Plot with line and markers\n",
- " ax.plot(\n",
- " k_converged,\n",
- " steps_converged,\n",
- " color=color,\n",
- " marker=\"o\",\n",
- " markersize=7,\n",
- " linewidth=2.5,\n",
- " label=f\"h={h:,}\",\n",
- " markeredgewidth=0.5,\n",
- " markeredgecolor=\"white\",\n",
- " )\n",
- "\n",
- " # Formatting\n",
- " ax.set_xlabel(\"Sequence Length ($k$)\", fontsize=14)\n",
- " ax.set_ylabel(\"Steps to Convergence\", fontsize=14)\n",
- " ax.set_title(\n",
- " f\"Steps to {reduction_threshold*100}% Convergence vs Sequence Length $k$\",\n",
- " fontsize=16,\n",
- " )\n",
- " if log_y:\n",
- " ax.set_yscale(\"log\")\n",
- " if log_x:\n",
- " ax.set_xscale(\"log\")\n",
- " ax.grid(True, alpha=0.3, which=\"both\", linestyle=\"--\", linewidth=0.5)\n",
- " ax.legend(fontsize=11, framealpha=0.9, loc=\"best\")\n",
- "\n",
- " # Make k values discrete on x-axis\n",
- " ax.set_xticks(k_values)\n",
- " ax.set_xticklabels(k_values)\n",
- "\n",
- " plt.tight_layout()\n",
- "\n",
- " if save_path:\n",
- " plt.savefig(save_path, dpi=150, bbox_inches=\"tight\")\n",
- " print(f\"Saved to {save_path}\")\n",
- "\n",
- " if show:\n",
- " plt.show()\n",
- " else:\n",
- " plt.close()\n",
- "\n",
- " return fig, ax\n",
- "\n",
- "\n",
- "reduction_threshold = 0.9\n",
- "conv_grid, conv_std = load_sweep_results_grid_convergence(\n",
- " sweep_dir,\n",
- " k_values,\n",
- " hidden_dims,\n",
- " reduction_threshold=reduction_threshold,\n",
- ")\n",
- "\n",
- "\n",
- "plot_convergence_vs_k(\n",
- " conv_grid=conv_grid,\n",
- " k_values=k_values,\n",
- " hidden_dims=hidden_dims,\n",
- " save_path=None,\n",
- " show=True,\n",
- " log_x=False,\n",
- " log_y=True,\n",
- " reduction_threshold=reduction_threshold,\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "3e814b11",
- "metadata": {},
- "source": [
- "### Curve plot : Normalized Convergence Steps vs Sequence Length for different hidden dimensions\n",
- "- x-axis: sequence length\n",
- "- y-axis: normalized convergence steps ($\\text{steps} / |G|^k$)\n",
- "- different curves for different hidden dimensions $H$"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e116a0e1",
- "metadata": {},
- "outputs": [],
- "source": [
- "def plot_convergence_vs_k_normalized(\n",
- " conv_grid,\n",
- " k_values,\n",
- " hidden_dims,\n",
- " p: int = 100, # Vocabulary size\n",
- " batch_size: int = 1000, # Batch size used in training\n",
- " save_path=None,\n",
- " show=True,\n",
- " log_x=False,\n",
- " log_y=True,\n",
- " reduction_threshold=0.9,\n",
- "):\n",
- " \"\"\"\n",
- " Plot fraction of data space seen to convergence vs sequence length k.\n",
- "\n",
- " Normalizes steps to convergence by the data space size (p^k) to show\n",
- " what fraction of the data space needs to be seen for convergence.\n",
- "\n",
- " Args:\n",
- " conv_grid: 2D array (len(hidden_dims), len(k_values)) with convergence steps\n",
- " k_values: List of k (sequence length) values\n",
- " hidden_dims: List of hidden dimension values\n",
- " p: Vocabulary size (data space per token)\n",
- " batch_size: Batch size used during training\n",
- " save_path: Where to save the plot\n",
- " show: Whether to display the plot\n",
- " log_x: Whether to use log scale for x-axis\n",
- " log_y: Whether to use log scale for y-axis\n",
- " reduction_threshold: Threshold used for convergence\n",
- " \"\"\"\n",
- " fig, ax = plt.subplots(figsize=(10, 6))\n",
- "\n",
- " # Use a nice sequential colormap for different widths\n",
- " colors = plt.cm.plasma(np.linspace(0.15, 0.95, len(hidden_dims)))\n",
- "\n",
- " for i, (h, color) in enumerate(zip(hidden_dims, colors)):\n",
- " # Extract convergence steps for this hidden dim across all k values\n",
- " steps_for_h = conv_grid[i, :]\n",
- "\n",
- " # Only plot converged points\n",
- " converged_mask = ~np.isnan(steps_for_h)\n",
- " k_converged = np.array(k_values)[converged_mask]\n",
- " steps_converged = steps_for_h[converged_mask]\n",
- "\n",
- " if len(steps_converged) > 0:\n",
- " # Normalize by data space size for each k\n",
- " # samples_seen = steps * batch_size\n",
- " # fraction = samples_seen / p^k\n",
- " fractions = []\n",
- " for k_val, steps_val in zip(k_converged, steps_converged):\n",
- " data_space_size = p**k_val\n",
- " samples_seen = steps_val * batch_size\n",
- " fraction = samples_seen / data_space_size\n",
- " fractions.append(fraction)\n",
- "\n",
- " # Plot with line and markers\n",
- " ax.plot(\n",
- " k_converged,\n",
- " fractions,\n",
- " color=color,\n",
- " marker=\"o\",\n",
- " markersize=7,\n",
- " linewidth=2.5,\n",
- " label=f\"h={h:,}\",\n",
- " markeredgewidth=0.5,\n",
- " markeredgecolor=\"white\",\n",
- " )\n",
- "\n",
- " # Formatting\n",
- " ax.set_xlabel(\"Sequence Length (k)\", fontsize=14)\n",
- " ax.set_ylabel(\"Data points seen / $p^k$ to convergence\", fontsize=14)\n",
- " ax.set_title(\n",
- " f\"Data Efficiency to {reduction_threshold*100}% Convergence\",\n",
- " fontsize=16,\n",
- " )\n",
- "\n",
- " if log_y:\n",
- " ax.set_yscale(\"log\")\n",
- " if log_x:\n",
- " ax.set_xscale(\"log\")\n",
- " else:\n",
- " # Make k values discrete on x-axis\n",
- " ax.set_xticks(k_values)\n",
- " ax.set_xticklabels(k_values)\n",
- "\n",
- " ax.grid(True, alpha=0.3, which=\"both\", linestyle=\"--\", linewidth=0.5)\n",
- " ax.legend(fontsize=11, framealpha=0.9, loc=\"best\")\n",
- "\n",
- " plt.tight_layout()\n",
- "\n",
- " if save_path:\n",
- " plt.savefig(save_path, dpi=150, bbox_inches=\"tight\")\n",
- " print(f\"Saved to {save_path}\")\n",
- "\n",
- " if show:\n",
- " plt.show()\n",
- " else:\n",
- " plt.close()\n",
- "\n",
- " return fig, ax\n",
- "\n",
- "\n",
- "reduction_threshold = 0.9\n",
- "conv_grid, conv_std = load_sweep_results_grid_convergence(\n",
- " sweep_dir,\n",
- " k_values,\n",
- " hidden_dims,\n",
- " reduction_threshold=reduction_threshold,\n",
- ")\n",
- "\n",
- "\n",
- "plot_convergence_vs_k_normalized(\n",
- " conv_grid=conv_grid,\n",
- " k_values=k_values,\n",
- " hidden_dims=hidden_dims,\n",
- " p=10, # Your vocabulary size\n",
- " batch_size=1000, # Your batch size\n",
- " save_path=None,\n",
- " show=True,\n",
- " log_x=False,\n",
- " log_y=True,\n",
- " reduction_threshold=reduction_threshold,\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "e0f758af",
- "metadata": {},
- "source": [
- "### Curve plot: Loss vs Training Steps for different sequence lengths, fixed hidden dimension\n",
- "- x-axis: # training steps\n",
- "- y-axis: training loss\n",
- "- different curves for different sequence lengths"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "7e809783",
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "from pathlib import Path\n",
- "\n",
- "\n",
- "def plot_loss_curves_fixed_width(\n",
- " sweep_dir: str,\n",
- " k_values: list,\n",
- " hidden_dim: int = 77760,\n",
- " seed: int = 0,\n",
- " save_path: str = None,\n",
- " show: bool = True,\n",
- " log_x: bool = True,\n",
- " log_y: bool = True,\n",
- "):\n",
- " \"\"\"\n",
- " Plot loss curves for different sequence lengths k with fixed hidden dimension.\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to sweep directory\n",
- " k_values: List of k values to plot (e.g., [2, 3, 4, 5, 6, 7, 8])\n",
- " hidden_dim: Fixed hidden dimension (default: 77760)\n",
- " seed: Which seed to plot (default: 0)\n",
- " save_path: Where to save the plot\n",
- " show: Whether to display the plot\n",
- " log_x: Whether to use log scale for x-axis\n",
- " log_y: Whether to use log scale for y-axis\n",
- " \"\"\"\n",
- " sweep_path = Path(sweep_dir)\n",
- "\n",
- " # Create figure\n",
- " fig, ax = plt.subplots(figsize=(10, 6))\n",
- "\n",
- " # Use a nice sequential colormap (plasma, magma, cividis, YlOrRd, etc.)\n",
- " colors = plt.cm.plasma(\n",
- " np.linspace(0.15, 0.95, len(k_values))\n",
- " ) # Avoid too light/dark\n",
- "\n",
- " for k, color in zip(k_values, colors):\n",
- " run_dir = sweep_path / f\"k{k}_h{hidden_dim}\" / f\"seed_{seed}\"\n",
- " loss_file = run_dir / \"train_loss_history.npy\"\n",
- "\n",
- " if not loss_file.exists():\n",
- " print(f\"Warning: No data found for k={k}, h={hidden_dim}\")\n",
- " continue\n",
- "\n",
- " # Load loss history\n",
- " loss_history = np.load(loss_file)\n",
- " steps = np.arange(len(loss_history))\n",
- "\n",
- " # Plot\n",
- " ax.plot(steps, loss_history, color=color, lw=2.5, label=f\"k={k}\")\n",
- "\n",
- " # Formatting\n",
- " ax.set_xlabel(\"Training Steps\", fontsize=14)\n",
- " ax.set_ylabel(\"Training Loss\", fontsize=14)\n",
- " ax.set_title(f\"Loss vs Training Steps (h={hidden_dim:,})\", fontsize=16)\n",
- " if log_x:\n",
- " ax.set_xscale(\"log\")\n",
- " if log_y:\n",
- " ax.set_yscale(\"log\")\n",
- " ax.grid(True, alpha=0.3, which=\"both\", linestyle=\"--\", linewidth=0.5)\n",
- " ax.legend(fontsize=11, framealpha=0.9, loc=\"best\")\n",
- "\n",
- " plt.tight_layout()\n",
- "\n",
- " if save_path:\n",
- " plt.savefig(save_path, dpi=150, bbox_inches=\"tight\")\n",
- " print(f\"Saved to {save_path}\")\n",
- "\n",
- " if show:\n",
- " plt.show()\n",
- " else:\n",
- " plt.close()\n",
- "\n",
- " return fig, ax\n",
- "\n",
- "\n",
- "plot_loss_curves_fixed_width(\n",
- " sweep_dir=sweep_dir,\n",
- " k_values=[2, 3, 4, 5, 6, 7, 8],\n",
- " hidden_dim=6**2,\n",
- " seed=0,\n",
- " save_path=None,\n",
- " show=True,\n",
- " log_x=True,\n",
- " log_y=True,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e5cd8b97",
- "metadata": {},
- "outputs": [],
- "source": [
- "def compute_spikiness_metrics_upward_only(loss_history):\n",
- " \"\"\"\n",
- " Compute spikiness focusing ONLY on upward jumps (loss increases).\n",
- "\n",
- " This separates:\n",
- " - Fast learning (large downward jumps) = STABLE\n",
- " - Instability (upward jumps) = UNSTABLE/SPIKY\n",
- " \"\"\"\n",
- " log_loss = np.log10(loss_history + 1e-10)\n",
- " log_changes = np.diff(log_loss) # Can be positive or negative\n",
- "\n",
- " # Separate upward (bad) from downward (good)\n",
- " upward_spikes = log_changes[log_changes > 0] # Loss INCREASES\n",
- " downward_drops = log_changes[log_changes < 0] # Loss DECREASES\n",
- "\n",
- " metrics = {}\n",
- "\n",
- " # Count how many steps are upward vs downward\n",
- " metrics[\"n_upward\"] = len(upward_spikes)\n",
- " metrics[\"n_downward\"] = len(downward_drops)\n",
- " metrics[\"frac_upward\"] = (\n",
- " len(upward_spikes) / len(log_changes) if len(log_changes) > 0 else 0\n",
- " )\n",
- "\n",
- " if len(upward_spikes) > 0:\n",
- " # Metrics based ONLY on upward spikes (instability)\n",
- " metrics[\"upward_p95\"] = np.percentile(upward_spikes, 95)\n",
- " metrics[\"upward_p999\"] = np.percentile(upward_spikes, 99.9)\n",
- " metrics[\"upward_max\"] = np.max(upward_spikes)\n",
- " metrics[\"upward_mean\"] = np.mean(upward_spikes)\n",
- " metrics[\"upward_std\"] = np.std(upward_spikes)\n",
- " else:\n",
- " # Perfectly monotonic decrease (never went up!)\n",
- " metrics[\"upward_p95\"] = 0.0\n",
- " metrics[\"upward_p999\"] = 0.0\n",
- " metrics[\"upward_max\"] = 0.0\n",
- " metrics[\"upward_mean\"] = 0.0\n",
- " metrics[\"upward_std\"] = 0.0\n",
- "\n",
- " if len(downward_drops) > 0:\n",
- " # For reference: how fast is it learning?\n",
- " metrics[\"downward_p95\"] = np.percentile(\n",
- " np.abs(downward_drops), 95\n",
- " ) # Large drops = fast\n",
- " metrics[\"downward_mean\"] = np.mean(np.abs(downward_drops))\n",
- " else:\n",
- " metrics[\"downward_p95\"] = 0.0\n",
- " metrics[\"downward_mean\"] = 0.0\n",
- "\n",
- " # Ratio: upward spikes vs downward progress\n",
- " if metrics[\"downward_mean\"] > 0:\n",
- " metrics[\"spike_to_progress_ratio\"] = (\n",
- " metrics[\"upward_mean\"] / metrics[\"downward_mean\"]\n",
- " )\n",
- " else:\n",
- " metrics[\"spike_to_progress_ratio\"] = (\n",
- " np.inf if metrics[\"upward_mean\"] > 0 else 0.0\n",
- " )\n",
- "\n",
- " # Late-stage upward spikes (last 20%)\n",
- " cutoff = int(0.8 * len(log_changes))\n",
- " late_changes = log_changes[cutoff:]\n",
- " late_upward = late_changes[late_changes > 0]\n",
- "\n",
- " if len(late_upward) > 0:\n",
- " metrics[\"late_upward_p95\"] = np.percentile(late_upward, 95)\n",
- " metrics[\"late_upward_max\"] = np.max(late_upward)\n",
- " else:\n",
- " metrics[\"late_upward_p95\"] = 0.0\n",
- " metrics[\"late_upward_max\"] = 0.0\n",
- "\n",
- " return metrics\n",
- "\n",
- "\n",
- "def plot_loss_curves_with_upward_metrics(\n",
- " sweep_dir: str,\n",
- " k_values: list,\n",
- " hidden_dim: int = 36,\n",
- " seed: int = 0,\n",
- " save_path: str = None,\n",
- " show: bool = True,\n",
- " log_x: bool = True,\n",
- " log_y: bool = True,\n",
- "):\n",
- " \"\"\"Plot loss curves with metrics focused on upward spikes only.\"\"\"\n",
- " sweep_path = Path(sweep_dir)\n",
- "\n",
- " fig, ax = plt.subplots(figsize=(10, 6))\n",
- " colors = plt.cm.plasma(np.linspace(0.15, 0.95, len(k_values)))\n",
- "\n",
- " metrics_data = []\n",
- "\n",
- " for k, color in zip(k_values, colors):\n",
- " run_dir = sweep_path / f\"k{k}_h{hidden_dim}\" / f\"seed_{seed}\"\n",
- " loss_file = run_dir / \"train_loss_history.npy\"\n",
- "\n",
- " if not loss_file.exists():\n",
- " print(f\"Warning: No data found for k={k}, h={hidden_dim}\")\n",
- " continue\n",
- "\n",
- " loss_history = np.load(loss_file)\n",
- " steps = np.arange(len(loss_history))\n",
- "\n",
- " ax.plot(steps, loss_history, color=color, lw=2.5, label=f\"k={k}\")\n",
- "\n",
- " # Compute upward-only metrics\n",
- " metrics = compute_spikiness_metrics_upward_only(loss_history)\n",
- " metrics[\"k\"] = k\n",
- " metrics[\"n_steps\"] = len(loss_history)\n",
- " metrics_data.append(metrics)\n",
- "\n",
- " # Formatting\n",
- " ax.set_xlabel(\"Training Steps\", fontsize=14)\n",
- " ax.set_ylabel(\"Training Loss\", fontsize=14)\n",
- " ax.set_title(f\"Loss vs Training Steps (h={hidden_dim:,})\", fontsize=16)\n",
- " if log_x:\n",
- " ax.set_xscale(\"log\")\n",
- " if log_y:\n",
- " ax.set_yscale(\"log\")\n",
- " ax.grid(True, alpha=0.3, which=\"both\", linestyle=\"--\", linewidth=0.5)\n",
- " ax.legend(fontsize=11, framealpha=0.9, loc=\"best\")\n",
- "\n",
- " plt.tight_layout()\n",
- "\n",
- " if save_path:\n",
- " plt.savefig(save_path, dpi=150, bbox_inches=\"tight\")\n",
- " print(f\"Saved to {save_path}\")\n",
- "\n",
- " if show:\n",
- " plt.show()\n",
- " else:\n",
- " plt.close()\n",
- "\n",
- " # Print metrics\n",
- " df = pd.DataFrame(metrics_data)\n",
- " col_order = [\n",
- " \"k\",\n",
- " \"n_steps\",\n",
- " \"frac_upward\",\n",
- " \"upward_p95\",\n",
- " \"upward_p999\",\n",
- " \"upward_max\",\n",
- " \"late_upward_p95\",\n",
- " \"spike_to_progress_ratio\",\n",
- " \"downward_p95\",\n",
- " ]\n",
- " df = df[col_order]\n",
- "\n",
- " print(\"\\n\" + \"=\" * 100)\n",
- " print(f\"UPWARD SPIKE METRICS (h={hidden_dim})\")\n",
- " print(\"=\" * 100)\n",
- " print(\"\\nMetric Definitions:\")\n",
- " print(\" frac_upward : Fraction of steps where loss INCREASED\")\n",
- " print(\" upward_p95 : 95th percentile of upward jumps (instability)\")\n",
- " print(\" upward_p999 : 99.9th percentile of upward jumps\")\n",
- " print(\" upward_max : Worst upward spike\")\n",
- " print(\" late_upward_p95 : 95th percentile of upward jumps in last 20%\")\n",
- " print(\n",
- " \" spike_to_progress_ratio : Mean upward / mean downward (higher = more unstable)\"\n",
- " )\n",
- " print(\n",
- " \" downward_p95 : 95th percentile of downward jumps (learning speed)\"\n",
- " )\n",
- " print(\"\\n\" + \"-\" * 100)\n",
- "\n",
- " pd.set_option(\"display.max_columns\", None)\n",
- " pd.set_option(\"display.width\", None)\n",
- " pd.set_option(\"display.float_format\", \"{:.4f}\".format)\n",
- " print(df.to_string(index=False))\n",
- " print(\"=\" * 100 + \"\\n\")\n",
- "\n",
- " return fig, ax, df\n",
- "\n",
- "\n",
- "# Call it:\n",
- "fig, ax, metrics_df = plot_loss_curves_with_upward_metrics(\n",
- " sweep_dir=sweep_dir,\n",
- " k_values=[2, 3, 4, 5, 6, 7, 8],\n",
- " hidden_dim=6**6,\n",
- " seed=0,\n",
- " save_path=None,\n",
- " show=True,\n",
- " log_x=True,\n",
- " log_y=True,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "bd0b5026",
- "metadata": {},
- "outputs": [],
- "source": [
- "def load_sweep_results_grid_spikiness(\n",
- " sweep_dir: str,\n",
- " k_values: list,\n",
- " hidden_dims: list,\n",
- "):\n",
- " \"\"\"\n",
- " Compute fraction of training steps where loss increased (instability).\n",
- "\n",
- " Returns:\n",
- " grid: 2D array with mean frac_upward across seeds\n",
- " std_grid: 2D array with standard deviations\n",
- " \"\"\"\n",
- " sweep_path = Path(sweep_dir)\n",
- "\n",
- " grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n",
- " std_grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n",
- "\n",
- " for i, h in enumerate(hidden_dims):\n",
- " for j, k in enumerate(k_values):\n",
- " exp_name = f\"k{k}_h{h}\"\n",
- " exp_dir = sweep_path / exp_name\n",
- "\n",
- " if not exp_dir.exists():\n",
- " continue\n",
- "\n",
- " frac_upwards = []\n",
- " for seed_dir in exp_dir.glob(\"seed_*\"):\n",
- " loss_file = seed_dir / \"train_loss_history.npy\"\n",
- " if loss_file.exists():\n",
- " loss_history = np.load(loss_file)\n",
- " log_loss = np.log10(loss_history + 1e-10)\n",
- " log_changes = np.diff(log_loss)\n",
- "\n",
- " # Fraction of steps where loss went UP\n",
- " frac_upward = np.sum(log_changes > 0) / len(log_changes)\n",
- " frac_upwards.append(frac_upward)\n",
- "\n",
- " if frac_upwards:\n",
- " grid[i, j] = np.mean(frac_upwards)\n",
- " std_grid[i, j] = np.std(frac_upwards) if len(frac_upwards) > 1 else 0.0\n",
- "\n",
- " return grid, std_grid"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6d858f63",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Compute stability grid\n",
- "stability_grid, stability_std = load_sweep_results_grid_spikiness(\n",
- " sweep_dir, k_values, hidden_dims\n",
- ")\n",
- "\n",
- "# Plot\n",
- "plt.figure(figsize=(8, 6.5))\n",
- "plt.imshow(stability_grid, aspect=\"equal\", cmap=\"viridis\") # , norm=LogNorm())\n",
- "plt.xlabel(\"Sequence Length (k)\")\n",
- "plt.ylabel(\"Hidden Dimension\")\n",
- "ytick_labels = [f\"$6^{i+1}$ ({val:,})\" for i, val in enumerate(hidden_dims)]\n",
- "plt.yticks(range(len(hidden_dims)), ytick_labels)\n",
- "plt.xticks(range(len(k_values)), k_values)\n",
- "plt.gca().invert_yaxis()\n",
- "plt.colorbar(label=\"Training Spikiness\")\n",
- "plt.title(\"Training Spikiness\")\n",
- "plt.tight_layout()\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "298e758e",
- "metadata": {},
- "source": [
- "# Varying group size (num frequencies)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "74cd5103",
- "metadata": {},
- "outputs": [],
- "source": [
- "def load_sweep_results_grid_convergence_3d(\n",
- " sweep_dir: str,\n",
- " k_values: list,\n",
- " hidden_dims: list,\n",
- " num_frequencies: int,\n",
- " reduction_threshold: float = 0.99,\n",
- "):\n",
- " \"\"\"\n",
- " Load sweep results and compute steps to convergence for 3D sweeps over k, h, and f.\n",
- "\n",
- " This function is designed for sweeps that include num_frequencies as a parameter,\n",
- " using directory naming format: k{k}_h{h}_f{f}\n",
- "\n",
- " Convergence is defined as reaching `reduction_threshold` loss reduction\n",
- " (e.g., 0.99 = 99% reduction from initial loss).\n",
- "\n",
- " If convergence is not reached, the grid point is set to NaN (blacked out).\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to the sweep directory\n",
- " k_values: List of k (sequence length) values\n",
- " hidden_dims: List of hidden dimension values\n",
- " num_frequencies: Number of frequencies (f parameter)\n",
- " reduction_threshold: Fraction of loss reduction to consider converged\n",
- "\n",
- " Returns:\n",
- " grid: 2D array with mean steps to convergence (NaN if didn't converge)\n",
- " std_grid: 2D array with standard deviations across seeds\n",
- " \"\"\"\n",
- " sweep_path = Path(sweep_dir)\n",
- "\n",
- " grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n",
- " std_grid = np.full((len(hidden_dims), len(k_values)), np.nan)\n",
- "\n",
- " for i, h in enumerate(hidden_dims):\n",
- " for j, k in enumerate(k_values):\n",
- " exp_name = f\"k{k}_h{h}_f{num_frequencies}\"\n",
- " exp_dir = sweep_path / exp_name\n",
- "\n",
- " if not exp_dir.exists():\n",
- " continue\n",
- "\n",
- " # Collect convergence steps from all seeds\n",
- " convergence_steps = []\n",
- " for seed_dir in exp_dir.glob(\"seed_*\"):\n",
- " loss_file = seed_dir / \"train_loss_history.npy\"\n",
- " if loss_file.exists():\n",
- " loss_history = np.load(loss_file)\n",
- " initial_loss = loss_history[0]\n",
- "\n",
- " if initial_loss > 0:\n",
- " # Compute reduction at each step\n",
- " reductions = 1 - loss_history / initial_loss\n",
- "\n",
- " # Find first step where reduction >= threshold\n",
- " converged_mask = reductions >= reduction_threshold\n",
- " if np.any(converged_mask):\n",
- " step = np.argmax(converged_mask) # First True\n",
- " convergence_steps.append(step)\n",
- " # else: Never converged - don't add to list\n",
- "\n",
- " if convergence_steps:\n",
- " grid[i, j] = np.mean(convergence_steps)\n",
- " std_grid[i, j] = (\n",
- " np.std(convergence_steps) if len(convergence_steps) > 1 else 0.0\n",
- " )\n",
- " # else: No seeds converged - grid[i,j] remains NaN (blacked out)\n",
- "\n",
- " return grid, std_grid"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "418d1ac0",
- "metadata": {},
- "source": [
- "## num_freq = 2"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "01cf6f19",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Example using the new 3D sweep (k, h, f)\n",
- "new_sweep_dir = \"/home/facosta/group-agf/sweeps/sweep_mlp_scaling_20251212_172318\"\n",
- "\n",
- "# Define parameter values for the new sweep\n",
- "k_values_new = [2, 3, 4, 5, 6, 7, 8]\n",
- "hidden_dims_new = [6, 36, 216, 1296, 7776, 46656]\n",
- "num_frequencies = 2 # Set this to the frequency value you want to visualize\n",
- "\n",
- "# Load convergence data for a specific frequency\n",
- "reduction_threshold = 0.5\n",
- "conv_grid_new, conv_std_new = load_sweep_results_grid_convergence_3d(\n",
- " new_sweep_dir,\n",
- " k_values_new,\n",
- " hidden_dims_new,\n",
- " reduction_threshold=reduction_threshold,\n",
- " num_frequencies=num_frequencies,\n",
- ")\n",
- "\n",
- "# Plot the heatmap\n",
- "plt.figure(figsize=(8, 6))\n",
- "cmap = plt.cm.viridis_r.copy()\n",
- "cmap.set_bad(color=\"black\")\n",
- "plt.imshow(conv_grid_new, aspect=\"equal\", cmap=cmap, norm=LogNorm())\n",
- "\n",
- "plt.xlabel(\"Sequence Length (k)\")\n",
- "plt.ylabel(\"Hidden Dimension\")\n",
- "plt.xticks(range(len(k_values_new)), k_values_new)\n",
- "plt.yticks(range(len(hidden_dims_new)), hidden_dims_new)\n",
- "plt.gca().invert_yaxis()\n",
- "\n",
- "plt.colorbar(label=f\"Steps to {reduction_threshold*100}% Convergence\")\n",
- "plt.title(f\"Steps to Convergence (f={num_frequencies}, black = did not converge)\")\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "ec9b06b2",
- "metadata": {},
- "source": [
- "## num_freq = 3"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "32e86647",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Example using the new 3D sweep (k, h, f)\n",
- "new_sweep_dir = \"/home/facosta/group-agf/sweeps/sweep_mlp_scaling_20251212_172318\"\n",
- "\n",
- "# Define parameter values for the new sweep\n",
- "k_values_new = [2, 3, 4, 5, 6, 7, 8]\n",
- "hidden_dims_new = [6, 36, 216, 1296, 7776, 46656]\n",
- "num_frequencies = 3 # Set this to the frequency value you want to visualize\n",
- "\n",
- "# Load convergence data for a specific frequency\n",
- "reduction_threshold = 0.5\n",
- "conv_grid_new, conv_std_new = load_sweep_results_grid_convergence_3d(\n",
- " new_sweep_dir,\n",
- " k_values_new,\n",
- " hidden_dims_new,\n",
- " reduction_threshold=reduction_threshold,\n",
- " num_frequencies=num_frequencies,\n",
- ")\n",
- "\n",
- "# Plot the heatmap\n",
- "plt.figure(figsize=(8, 6))\n",
- "cmap = plt.cm.viridis_r.copy()\n",
- "cmap.set_bad(color=\"black\")\n",
- "plt.imshow(conv_grid_new, aspect=\"equal\", cmap=cmap, norm=LogNorm())\n",
- "\n",
- "plt.xlabel(\"Sequence Length (k)\")\n",
- "plt.ylabel(\"Hidden Dimension\")\n",
- "plt.xticks(range(len(k_values_new)), k_values_new)\n",
- "plt.yticks(range(len(hidden_dims_new)), hidden_dims_new)\n",
- "plt.gca().invert_yaxis()\n",
- "\n",
- "plt.colorbar(label=f\"Steps to {reduction_threshold*100}% Convergence\")\n",
- "plt.title(f\"Steps to Convergence (f={num_frequencies}, black = did not converge)\")\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "44a4e9c7",
- "metadata": {},
- "source": [
- "## num_freq = 4"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e77f1a4e",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Example using the new 3D sweep (k, h, f)\n",
- "new_sweep_dir = \"/home/facosta/group-agf/sweeps/sweep_mlp_scaling_20251212_172318\"\n",
- "\n",
- "# Define parameter values for the new sweep\n",
- "k_values_new = [2, 3, 4, 5, 6, 7, 8]\n",
- "hidden_dims_new = [6, 36, 216, 1296, 7776, 46656]\n",
- "num_frequencies = 4 # Set this to the frequency value you want to visualize\n",
- "\n",
- "# Load convergence data for a specific frequency\n",
- "reduction_threshold = 0.5\n",
- "conv_grid_new, conv_std_new = load_sweep_results_grid_convergence_3d(\n",
- " new_sweep_dir,\n",
- " k_values_new,\n",
- " hidden_dims_new,\n",
- " reduction_threshold=reduction_threshold,\n",
- " num_frequencies=num_frequencies,\n",
- ")\n",
- "\n",
- "# Plot the heatmap\n",
- "plt.figure(figsize=(8, 6))\n",
- "cmap = plt.cm.viridis_r.copy()\n",
- "cmap.set_bad(color=\"black\")\n",
- "plt.imshow(conv_grid_new, aspect=\"equal\", cmap=cmap, norm=LogNorm())\n",
- "\n",
- "plt.xlabel(\"Sequence Length (k)\")\n",
- "plt.ylabel(\"Hidden Dimension\")\n",
- "plt.xticks(range(len(k_values_new)), k_values_new)\n",
- "plt.yticks(range(len(hidden_dims_new)), hidden_dims_new)\n",
- "plt.gca().invert_yaxis()\n",
- "\n",
- "plt.colorbar(label=f\"Steps to {reduction_threshold*100}% Convergence\")\n",
- "plt.title(f\"Steps to Convergence (f={num_frequencies}, black = did not converge)\")\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "aa9ae86c",
- "metadata": {},
- "source": [
- "## num_freq = 5"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0ce18ab9",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Example using the new 3D sweep (k, h, f)\n",
- "new_sweep_dir = \"/home/facosta/group-agf/sweeps/sweep_mlp_scaling_20251212_172318\"\n",
- "\n",
- "# Define parameter values for the new sweep\n",
- "k_values_new = [2, 3, 4, 5, 6, 7, 8]\n",
- "hidden_dims_new = [6, 36, 216, 1296, 7776, 46656]\n",
- "num_frequencies = 5 # Set this to the frequency value you want to visualize\n",
- "\n",
- "# Load convergence data for a specific frequency\n",
- "reduction_threshold = 0.5\n",
- "conv_grid_new, conv_std_new = load_sweep_results_grid_convergence_3d(\n",
- " new_sweep_dir,\n",
- " k_values_new,\n",
- " hidden_dims_new,\n",
- " reduction_threshold=reduction_threshold,\n",
- " num_frequencies=num_frequencies,\n",
- ")\n",
- "\n",
- "# Plot the heatmap\n",
- "plt.figure(figsize=(8, 6))\n",
- "cmap = plt.cm.viridis_r.copy()\n",
- "cmap.set_bad(color=\"black\")\n",
- "plt.imshow(conv_grid_new, aspect=\"equal\", cmap=cmap, norm=LogNorm())\n",
- "\n",
- "plt.xlabel(\"Sequence Length (k)\")\n",
- "plt.ylabel(\"Hidden Dimension\")\n",
- "plt.xticks(range(len(k_values_new)), k_values_new)\n",
- "plt.yticks(range(len(hidden_dims_new)), hidden_dims_new)\n",
- "plt.gca().invert_yaxis()\n",
- "\n",
- "plt.colorbar(label=f\"Steps to {reduction_threshold*100}% Convergence\")\n",
- "plt.title(f\"Steps to Convergence (f={num_frequencies}, black = did not converge)\")\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "f8585f79",
- "metadata": {},
- "source": [
- "### Grid plot: Convergence vs k for different num_frequencies, across different hidden dimensions\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "97c43fc0",
- "metadata": {},
- "outputs": [],
- "source": [
- "def plot_convergence_vs_k_grid_by_frequency(\n",
- " sweep_dir: str,\n",
- " k_values: list,\n",
- " hidden_dims: list,\n",
- " num_frequencies_values: list,\n",
- " reduction_threshold: float = 0.99,\n",
- " figsize=(18, 12),\n",
- " log_x=False,\n",
- " log_y=True,\n",
- " save_path=None,\n",
- " show=True,\n",
- "):\n",
- " \"\"\"\n",
- " Create a grid of plots showing convergence vs k for different frequencies.\n",
- "\n",
- " Each subplot corresponds to a different hidden dimension.\n",
- " Within each subplot, different curves represent different num_frequencies values.\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to the sweep directory\n",
- " k_values: List of k (sequence length) values\n",
- " hidden_dims: List of hidden dimension values (one subplot per hidden dim)\n",
- " num_frequencies_values: List of num_frequencies values to compare\n",
- " reduction_threshold: Threshold for convergence definition\n",
- " figsize: Figure size tuple\n",
- " log_x: Whether to use log scale for x-axis\n",
- " log_y: Whether to use log scale for y-axis\n",
- " save_path: Where to save the plot\n",
- " show: Whether to display the plot\n",
- "\n",
- " Returns:\n",
- " fig, axes: Matplotlib figure and axes objects\n",
- " \"\"\"\n",
- " sweep_path = Path(sweep_dir)\n",
- "\n",
- " # Determine grid layout (2x3 or 3x2 based on number of hidden dims)\n",
- " n_plots = len(hidden_dims)\n",
- " if n_plots == 6:\n",
- " nrows, ncols = 2, 3\n",
- " elif n_plots == 4:\n",
- " nrows, ncols = 2, 2\n",
- " else:\n",
- " # General case: aim for squarish layout\n",
- " ncols = int(np.ceil(np.sqrt(n_plots)))\n",
- " nrows = int(np.ceil(n_plots / ncols))\n",
- "\n",
- " fig, axes = plt.subplots(nrows, ncols, figsize=figsize)\n",
- " axes_flat = axes.flatten() if n_plots > 1 else [axes]\n",
- "\n",
- " # Use a nice colormap for different frequencies\n",
- " colors = plt.cm.viridis(np.linspace(0.15, 0.85, len(num_frequencies_values)))\n",
- "\n",
- " for idx, h in enumerate(hidden_dims):\n",
- " ax = axes_flat[idx]\n",
- "\n",
- " # For each frequency, plot convergence vs k\n",
- " for f_idx, num_freq in enumerate(num_frequencies_values):\n",
- " convergence_steps_for_k = []\n",
- " k_values_converged = []\n",
- "\n",
- " for k in k_values:\n",
- " exp_name = f\"k{k}_h{h}_f{num_freq}\"\n",
- " exp_dir = sweep_path / exp_name\n",
- "\n",
- " if not exp_dir.exists():\n",
- " continue\n",
- "\n",
- " # Collect convergence steps from all seeds\n",
- " convergence_steps = []\n",
- " for seed_dir in exp_dir.glob(\"seed_*\"):\n",
- " loss_file = seed_dir / \"train_loss_history.npy\"\n",
- " if loss_file.exists():\n",
- " loss_history = np.load(loss_file)\n",
- " initial_loss = loss_history[0]\n",
- "\n",
- " if initial_loss > 0:\n",
- " reductions = 1 - loss_history / initial_loss\n",
- " converged_mask = reductions >= reduction_threshold\n",
- " if np.any(converged_mask):\n",
- " step = np.argmax(converged_mask)\n",
- " convergence_steps.append(step)\n",
- "\n",
- " # Take mean across seeds if any converged\n",
- " if convergence_steps:\n",
- " k_values_converged.append(k)\n",
- " convergence_steps_for_k.append(np.mean(convergence_steps))\n",
- "\n",
- " # Plot this frequency's curve\n",
- " if len(k_values_converged) > 0:\n",
- " ax.plot(\n",
- " k_values_converged,\n",
- " convergence_steps_for_k,\n",
- " color=colors[f_idx],\n",
- " marker=\"o\",\n",
- " markersize=6,\n",
- " linewidth=2,\n",
- " label=f\"f={num_freq}\",\n",
- " markeredgewidth=0.5,\n",
- " markeredgecolor=\"white\",\n",
- " )\n",
- "\n",
- " # Formatting for this subplot\n",
- " ax.set_xlabel(\"Sequence Length (k)\", fontsize=11)\n",
- " ax.set_ylabel(\"Steps to Convergence\", fontsize=11)\n",
- " ax.set_title(f\"h = {h:,}\", fontsize=13, fontweight=\"bold\")\n",
- "\n",
- " if log_y:\n",
- " ax.set_yscale(\"log\")\n",
- " if log_x:\n",
- " ax.set_xscale(\"log\")\n",
- " else:\n",
- " # Make k values discrete on x-axis\n",
- " ax.set_xticks(k_values)\n",
- " ax.set_xticklabels(k_values)\n",
- "\n",
- " ax.grid(True, alpha=0.3, which=\"both\", linestyle=\"--\", linewidth=0.5)\n",
- " ax.legend(fontsize=9, framealpha=0.9, loc=\"best\")\n",
- "\n",
- " # Hide any unused subplots\n",
- " for idx in range(n_plots, len(axes_flat)):\n",
- " axes_flat[idx].axis(\"off\")\n",
- "\n",
- " # Overall title\n",
- " fig.suptitle(\n",
- " f\"Convergence vs Sequence Length by Number of Frequencies\\n\"\n",
- " f\"({reduction_threshold*100:.0f}% Loss Reduction Threshold)\",\n",
- " fontsize=16,\n",
- " fontweight=\"bold\",\n",
- " y=0.995,\n",
- " )\n",
- "\n",
- " plt.tight_layout()\n",
- "\n",
- " if save_path:\n",
- " plt.savefig(save_path, dpi=150, bbox_inches=\"tight\")\n",
- " print(f\"Saved to {save_path}\")\n",
- "\n",
- " if show:\n",
- " plt.show()\n",
- " else:\n",
- " plt.close()\n",
- "\n",
- " return fig, axes"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "acd00a0a",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Example usage: Create grid plot\n",
- "new_sweep_dir = \"/home/facosta/group-agf/sweeps/sweep_mlp_scaling_20251212_172318\"\n",
- "\n",
- "k_values_new = [2, 3, 4, 5, 6, 7, 8]\n",
- "hidden_dims_new = [6, 36, 216, 1296, 7776, 46656]\n",
- "num_frequencies_values = [2, 3, 4, 5] # All frequency values to compare\n",
- "\n",
- "plot_convergence_vs_k_grid_by_frequency(\n",
- " sweep_dir=new_sweep_dir,\n",
- " k_values=k_values_new,\n",
- " hidden_dims=hidden_dims_new,\n",
- " num_frequencies_values=num_frequencies_values,\n",
- " reduction_threshold=0.5,\n",
- " figsize=(18, 12),\n",
- " log_x=False,\n",
- " log_y=True,\n",
- " save_path=None,\n",
- " show=True,\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "85c094f2",
- "metadata": {},
- "source": [
- "## p=2 experiments"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "3a80c6fc",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Define sweep directory and parameters\n",
- "sweep_dir = \"/home/facosta/group-agf/sweeps/p2_scaling_sweep_20251215_205347\"\n",
- "\n",
- "# Parameters from p2_scaling_sweep.yaml\n",
- "k_values = [2, 3, 4, 5, 6, 7, 8]\n",
- "hidden_dims = [\n",
- " 4,\n",
- " 8,\n",
- " 16,\n",
- " 32,\n",
- " 64,\n",
- " 128,\n",
- " 256,\n",
- " 512,\n",
- " 1024,\n",
- " 2048,\n",
- " 4096,\n",
- " 8192,\n",
- " 16384,\n",
- " 32768,\n",
- " 65536,\n",
- "]\n",
- "\n",
- "# Load convergence data\n",
- "reduction_threshold = 0.6 # 90% loss reduction\n",
- "conv_grid, conv_std = load_sweep_results_grid_convergence(\n",
- " sweep_dir, k_values, hidden_dims, reduction_threshold=reduction_threshold\n",
- ")\n",
- "\n",
- "\n",
- "from matplotlib.colors import LogNorm\n",
- "\n",
- "# Plot the heatmap\n",
- "plt.figure(figsize=(10, 8))\n",
- "cmap = plt.cm.viridis_r.copy()\n",
- "cmap.set_bad(color=\"black\")\n",
- "plt.imshow(conv_grid, aspect=\"auto\", cmap=cmap, norm=LogNorm())\n",
- "\n",
- "plt.xlabel(\"Sequence Length (k)\", fontsize=12)\n",
- "plt.ylabel(\"Hidden Dimension (h)\", fontsize=12)\n",
- "plt.xticks(range(len(k_values)), k_values)\n",
- "\n",
- "# Create y-tick labels with both power notation and actual values for larger dims\n",
- "ytick_labels = []\n",
- "for h in hidden_dims:\n",
- " if h >= 1024:\n",
- " power = int(np.log2(h))\n",
- " ytick_labels.append(f\"$2^{{{power}}}$ ({h:,})\")\n",
- " else:\n",
- " ytick_labels.append(f\"{h}\")\n",
- "\n",
- "plt.yticks(range(len(hidden_dims)), ytick_labels, fontsize=9)\n",
- "plt.gca().invert_yaxis()\n",
- "\n",
- "# Add theoretical boundary line (h > p^(k-1), where p=2)\n",
- "# For p=2: boundary at h = 2^(k-1), so h=1,2,4,8,16,32,64\n",
- "x_step = np.arange(len(k_values)) - 0.5\n",
- "# Find y index where h = 2^(k-1) for each k\n",
- "y_boundary = []\n",
- "for i, k in enumerate(k_values):\n",
- " boundary_h = 2 ** (k - 1)\n",
- " # Find closest hidden_dim index\n",
- " try:\n",
- " y_idx = hidden_dims.index(boundary_h)\n",
- " except ValueError:\n",
- " # If exact match not found, find closest\n",
- " y_idx = np.argmin(np.abs(np.array(hidden_dims) - boundary_h))\n",
- " y_boundary.append(y_idx)\n",
- "\n",
- "plt.step(\n",
- " x_step,\n",
- " y_boundary,\n",
- " where=\"post\",\n",
- " color=\"red\",\n",
- " linewidth=3,\n",
- " linestyle=\"--\",\n",
- " label=r\"Theory boundary ($h > 2^{k-1}$)\",\n",
- ")\n",
- "\n",
- "plt.legend(loc=\"upper left\", fontsize=11, frameon=True)\n",
- "plt.colorbar(label=f\"Steps to {reduction_threshold*100:.0f}% Convergence\")\n",
- "plt.title(\n",
- " f\"p=2 Scaling: Steps to {reduction_threshold*100:.0f}% Convergence\\n(black = did not converge)\",\n",
- " fontsize=13,\n",
- " fontweight=\"bold\",\n",
- ")\n",
- "plt.tight_layout()\n",
- "plt.show()\n",
- "\n",
- "# Print summary statistics\n",
- "n_not_converged = np.sum(np.isnan(conv_grid))\n",
- "n_converged = np.sum(~np.isnan(conv_grid))\n",
- "total = conv_grid.size\n",
- "\n",
- "print(f\"\\n{'='*60}\")\n",
- "print(f\"CONVERGENCE SUMMARY (p=2 scaling)\")\n",
- "print(f\"{'='*60}\")\n",
- "print(f\"Converged: {n_converged:3d} ({100*n_converged/total:.1f}%)\")\n",
- "print(f\"Did not converge: {n_not_converged:3d} ({100*n_not_converged/total:.1f}%)\")\n",
- "print(f\"Total experiments: {total:3d}\")\n",
- "print(f\"{'='*60}\\n\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b1e10a36",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "group-agf",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.2"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/notebooks/seq_mlp_group_size.ipynb b/notebooks/seq_mlp_group_size.ipynb
deleted file mode 100644
index a30e7ec..0000000
--- a/notebooks/seq_mlp_group_size.ipynb
+++ /dev/null
@@ -1,1393 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "c8c5c4b6",
- "metadata": {},
- "source": [
- "# MLP Scaling: $H$ vs $|G|$ \n",
- "\n",
- "Hidden neurons vs group size scaling experiments."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "155908c2",
- "metadata": {},
- "source": [
- "## Set up"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "7fc4c5b6",
- "metadata": {},
- "outputs": [],
- "source": [
- "# autoreload\n",
- "%load_ext autoreload\n",
- "%autoreload 2\n",
- "# jupyter black formatter\n",
- "%load_ext jupyter_black\n",
- "\n",
- "import subprocess\n",
- "import os\n",
- "import sys\n",
- "\n",
- "gitroot_path = subprocess.check_output(\n",
- " [\"git\", \"rev-parse\", \"--show-toplevel\"], universal_newlines=True\n",
- ").strip()\n",
- "\n",
- "os.chdir(gitroot_path)\n",
- "print(\"Working directory: \", os.getcwd())\n",
- "\n",
- "if gitroot_path not in sys.path:\n",
- " sys.path.insert(0, gitroot_path)\n",
- "print(\"Directory added to path: \", gitroot_path)\n",
- "\n",
- "import yaml\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "from pathlib import Path"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "9831010d",
- "metadata": {},
- "source": [
- "## Specify experiment directory"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b9f8fc25",
- "metadata": {},
- "outputs": [],
- "source": [
- "# sweep_dir = \"/home/facosta/group-agf/sweeps/onehot_scaling_sweep_20251215_175955\"\n",
- "sweep_dir = \"/home/facosta/group-agf/sweep_results/onehot_scaling_sweep_20260112_022012\"\n",
- "print(os.path.exists(sweep_dir))"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "d8342c22",
- "metadata": {},
- "source": [
- "### Steps to Convergence"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "bc6dd932",
- "metadata": {},
- "outputs": [],
- "source": [
- "def load_sweep_results_grid_convergence_p_h(\n",
- " sweep_dir: str,\n",
- " k: int,\n",
- " p_values: list,\n",
- " hidden_dims: list,\n",
- " reduction_threshold: float = 0.99,\n",
- " max_p: int = None,\n",
- "):\n",
- " \"\"\"\n",
- " Load sweep results and compute steps to convergence for p vs hidden_dim sweep.\n",
- "\n",
- " Updated for experiment naming: k{k}_p{p}_h{h}\n",
- " Only loads completed experiments (checks for run_summary.yaml).\n",
- "\n",
- " Convergence is defined as reaching `reduction_threshold` loss reduction\n",
- " (e.g., 0.99 = 99% reduction from initial loss).\n",
- "\n",
- " If convergence is not reached, the grid point is set to NaN (blacked out).\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to the sweep directory\n",
- " k: Sequence length parameter (2, 3, or 4)\n",
- " p_values: List of p (group size) values\n",
- " hidden_dims: List of hidden dimension values\n",
- " reduction_threshold: Fraction of loss reduction to consider converged\n",
- " max_p: Maximum p value to include (filters incomplete experiments)\n",
- "\n",
- " Returns:\n",
- " grid: 2D array with mean steps to convergence (NaN if didn't converge)\n",
- " Shape: (len(hidden_dims), len(p_values))\n",
- " std_grid: 2D array with standard deviations across seeds\n",
- " \"\"\"\n",
- " sweep_path = Path(sweep_dir)\n",
- "\n",
- " grid = np.full((len(hidden_dims), len(p_values)), np.nan)\n",
- " std_grid = np.full((len(hidden_dims), len(p_values)), np.nan)\n",
- "\n",
- " for i, h in enumerate(hidden_dims):\n",
- " for j, p in enumerate(p_values):\n",
- " # Filter by max_p if specified\n",
- " if max_p is not None and p > max_p:\n",
- " continue\n",
- "\n",
- " exp_name = f\"k{k}_p{p}_h{h}\"\n",
- " exp_dir = sweep_path / exp_name\n",
- "\n",
- " if not exp_dir.exists():\n",
- " continue\n",
- "\n",
- " # Check if experiment is completed (has run_summary.yaml)\n",
- " seed_dir = exp_dir / \"seed_0\"\n",
- " if not seed_dir.exists() or not (seed_dir / \"run_summary.yaml\").exists():\n",
- " continue # Skip incomplete experiments\n",
- "\n",
- " # Collect convergence steps from all seeds\n",
- " convergence_steps = []\n",
- " for seed_dir in exp_dir.glob(\"seed_*\"):\n",
- " loss_file = seed_dir / \"train_loss_history.npy\"\n",
- " if loss_file.exists():\n",
- " loss_history = np.load(loss_file)\n",
- " initial_loss = loss_history[0]\n",
- "\n",
- " if initial_loss > 0:\n",
- " # Compute reduction at each step\n",
- " reductions = 1 - loss_history / initial_loss\n",
- "\n",
- " # Find first step where reduction >= threshold\n",
- " converged_mask = reductions >= reduction_threshold\n",
- " if np.any(converged_mask):\n",
- " step = np.argmax(converged_mask) # First True\n",
- " convergence_steps.append(step)\n",
- " # else: Never converged - don't add to list\n",
- "\n",
- " if convergence_steps:\n",
- " grid[i, j] = np.mean(convergence_steps)\n",
- " std_grid[i, j] = (\n",
- " np.std(convergence_steps) if len(convergence_steps) > 1 else 0.0\n",
- " )\n",
- " # else: No seeds converged - grid[i,j] remains NaN (blacked out)\n",
- "\n",
- " return grid, std_grid"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9a87f24d",
- "metadata": {},
- "outputs": [],
- "source": [
- "def load_sweep_results_grid_final_loss_p_h(\n",
- " sweep_dir: str,\n",
- " k: int,\n",
- " p_values: list,\n",
- " hidden_dims: list,\n",
- " max_p: int = None,\n",
- "):\n",
- " \"\"\"\n",
- " Load sweep results and compute final training loss for p vs hidden_dim sweep.\n",
- "\n",
- " Updated for experiment naming: k{k}_p{p}_h{h}\n",
- " Only loads completed experiments (checks for run_summary.yaml).\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to the sweep directory\n",
- " k: Sequence length parameter (2, 3, or 4)\n",
- " p_values: List of p (group size) values\n",
- " hidden_dims: List of hidden dimension values\n",
- " max_p: Maximum p value to include (filters incomplete experiments)\n",
- "\n",
- " Returns:\n",
- " grid: 2D array with mean final training loss (NaN if experiment incomplete)\n",
- " Shape: (len(hidden_dims), len(p_values))\n",
- " std_grid: 2D array with standard deviations across seeds\n",
- " \"\"\"\n",
- " sweep_path = Path(sweep_dir)\n",
- "\n",
- " grid = np.full((len(hidden_dims), len(p_values)), np.nan)\n",
- " std_grid = np.full((len(hidden_dims), len(p_values)), np.nan)\n",
- "\n",
- " for i, h in enumerate(hidden_dims):\n",
- " for j, p in enumerate(p_values):\n",
- " # Filter by max_p if specified\n",
- " if max_p is not None and p > max_p:\n",
- " continue\n",
- "\n",
- " exp_name = f\"k{k}_p{p}_h{h}\"\n",
- " exp_dir = sweep_path / exp_name\n",
- "\n",
- " if not exp_dir.exists():\n",
- " continue\n",
- "\n",
- " # Check if experiment is completed (has run_summary.yaml)\n",
- " seed_dir = exp_dir / \"seed_0\"\n",
- " if not seed_dir.exists() or not (seed_dir / \"run_summary.yaml\").exists():\n",
- " continue # Skip incomplete experiments\n",
- "\n",
- " # Collect final losses from all seeds\n",
- " final_losses = []\n",
- " for seed_dir in exp_dir.glob(\"seed_*\"):\n",
- " loss_file = seed_dir / \"train_loss_history.npy\"\n",
- " if loss_file.exists():\n",
- " loss_history = np.load(loss_file)\n",
- " if len(loss_history) > 0:\n",
- " final_loss = loss_history[-1] # Last value\n",
- " final_losses.append(final_loss)\n",
- "\n",
- " if final_losses:\n",
- " grid[i, j] = np.mean(final_losses)\n",
- " std_grid[i, j] = np.std(final_losses) if len(final_losses) > 1 else 0.0\n",
- " # else: No seeds found - grid[i,j] remains NaN (blacked out)\n",
- "\n",
- " return grid, std_grid"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "3bb53f80",
- "metadata": {},
- "outputs": [],
- "source": [
- "def load_training_loss_curves_p(\n",
- " sweep_dir: str,\n",
- " k: int,\n",
- " hidden_dim: int,\n",
- " p_values: list,\n",
- "):\n",
- " \"\"\"\n",
- " Load training loss histories for different group sizes (p) with fixed k and hidden_dim.\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to the sweep directory\n",
- " k: Sequence length parameter (fixed)\n",
- " hidden_dim: Hidden dimension (fixed)\n",
- " p_values: List of p (group size) values to plot\n",
- "\n",
- " Returns:\n",
- " curves: Dictionary mapping p -> list of loss histories (one per seed)\n",
- " Each loss history is a numpy array\n",
- " \"\"\"\n",
- " sweep_path = Path(sweep_dir)\n",
- "\n",
- " curves = {}\n",
- "\n",
- " for p in p_values:\n",
- " exp_name = f\"k{k}_p{p}_h{hidden_dim}\"\n",
- " exp_dir = sweep_path / exp_name\n",
- "\n",
- " if not exp_dir.exists():\n",
- " continue\n",
- "\n",
- " # Check if experiment is completed\n",
- " seed_dir = exp_dir / \"seed_0\"\n",
- " if not seed_dir.exists() or not (seed_dir / \"run_summary.yaml\").exists():\n",
- " continue # Skip incomplete experiments\n",
- "\n",
- " # Collect loss histories from all seeds\n",
- " loss_histories = []\n",
- " for seed_dir in exp_dir.glob(\"seed_*\"):\n",
- " loss_file = seed_dir / \"train_loss_history.npy\"\n",
- " if loss_file.exists():\n",
- " loss_history = np.load(loss_file)\n",
- " if len(loss_history) > 0:\n",
- " loss_histories.append(loss_history)\n",
- "\n",
- " if loss_histories:\n",
- " curves[p] = loss_histories\n",
- "\n",
- " return curves"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "bf14dee1",
- "metadata": {},
- "outputs": [],
- "source": [
- "def load_sweep_results_grid_final_val_loss_p_h(\n",
- " sweep_dir: str,\n",
- " k: int,\n",
- " p_values: list,\n",
- " hidden_dims: list,\n",
- " max_p: int = None,\n",
- "):\n",
- " \"\"\"\n",
- " Load sweep results and compute final validation loss for p vs hidden_dim sweep.\n",
- "\n",
- " Updated for experiment naming: k{k}_p{p}_h{h}\n",
- " Only loads completed experiments (checks for run_summary.yaml).\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to the sweep directory\n",
- " k: Sequence length parameter (2, 3, or 4)\n",
- " p_values: List of p (group size) values\n",
- " hidden_dims: List of hidden dimension values\n",
- " max_p: Maximum p value to include (filters incomplete experiments)\n",
- "\n",
- " Returns:\n",
- " grid: 2D array with mean final validation loss (NaN if experiment incomplete)\n",
- " Shape: (len(hidden_dims), len(p_values))\n",
- " std_grid: 2D array with standard deviations across seeds\n",
- " \"\"\"\n",
- " sweep_path = Path(sweep_dir)\n",
- "\n",
- " grid = np.full((len(hidden_dims), len(p_values)), np.nan)\n",
- " std_grid = np.full((len(hidden_dims), len(p_values)), np.nan)\n",
- "\n",
- " for i, h in enumerate(hidden_dims):\n",
- " for j, p in enumerate(p_values):\n",
- " # Filter by max_p if specified\n",
- " if max_p is not None and p > max_p:\n",
- " continue\n",
- "\n",
- " exp_name = f\"k{k}_p{p}_h{h}\"\n",
- " exp_dir = sweep_path / exp_name\n",
- "\n",
- " if not exp_dir.exists():\n",
- " continue\n",
- "\n",
- " # Check if experiment is completed (has run_summary.yaml)\n",
- " seed_dir = exp_dir / \"seed_0\"\n",
- " if not seed_dir.exists() or not (seed_dir / \"run_summary.yaml\").exists():\n",
- " continue # Skip incomplete experiments\n",
- "\n",
- " # Collect final validation losses from all seeds\n",
- " final_losses = []\n",
- " for seed_dir in exp_dir.glob(\"seed_*\"):\n",
- " loss_file = seed_dir / \"val_loss_history.npy\"\n",
- " if loss_file.exists():\n",
- " loss_history = np.load(loss_file)\n",
- " if len(loss_history) > 0:\n",
- " final_loss = loss_history[-1] # Last value\n",
- " final_losses.append(final_loss)\n",
- "\n",
- " if final_losses:\n",
- " grid[i, j] = np.mean(final_losses)\n",
- " std_grid[i, j] = np.std(final_losses) if len(final_losses) > 1 else 0.0\n",
- " # else: No seeds found - grid[i,j] remains NaN (blacked out)\n",
- "\n",
- " return grid, std_grid"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "42ce6ffd",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Define parameter values from the sweep config\n",
- "# Filter to p <= 55 for completed experiments\n",
- "p_values = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70]\n",
- "hidden_dims = [80, 160, 240, 320, 400, 480, 560, 640, 720, 800, 880, 960, 1040, 1120]\n",
- "k_values = [2, 3] # , 4] # Different k values to plot separately"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "7bf99dee",
- "metadata": {},
- "source": [
- "### Plot steps to convergence grid"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "522570f5",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Load convergence data for each k value separately\n",
- "reduction_threshold = 0.90\n",
- "max_p = 70 # Only visualize completed experiments (p <= 55)\n",
- "\n",
- "from matplotlib.colors import LogNorm\n",
- "\n",
- "# Create separate plots for each k value\n",
- "for k in k_values:\n",
- " conv_grid, conv_std = load_sweep_results_grid_convergence_p_h(\n",
- " sweep_dir,\n",
- " k,\n",
- " p_values,\n",
- " hidden_dims,\n",
- " reduction_threshold=reduction_threshold,\n",
- " max_p=max_p,\n",
- " )\n",
- "\n",
- " # Filter p values - only show p <= max_p\n",
- " p_values_filtered = [p for p in p_values if p <= max_p]\n",
- "\n",
- " # Plot convergence heatmap: p (group size) vs hidden_dim\n",
- " plt.figure(figsize=(12, 8))\n",
- " cmap = plt.cm.viridis_r.copy()\n",
- " cmap.set_bad(color=\"black\")\n",
- " # Set extent to align cells with tick positions\n",
- " # extent: [left, right, bottom, top] in data coordinates\n",
- " plt.imshow(\n",
- " conv_grid[:, : len(p_values_filtered)],\n",
- " aspect=\"equal\",\n",
- " cmap=cmap,\n",
- " norm=LogNorm(),\n",
- " )\n",
- "\n",
- " plt.xlabel(\"Group Size $|G|$\", fontsize=14)\n",
- " plt.ylabel(\"Hidden Dimension $H$\", fontsize=14)\n",
- " plt.xticks(\n",
- " range(len(p_values_filtered)), p_values_filtered, rotation=45, ha=\"center\"\n",
- " )\n",
- "\n",
- " # Set y-axis ticks (hidden dimensions)\n",
- " plt.yticks(range(len(hidden_dims)), hidden_dims)\n",
- " plt.gca().invert_yaxis()\n",
- "\n",
- " # Theory boundaries\n",
- " x_step = np.arange(len(p_values_filtered) + 1) - 0.5\n",
- "\n",
- " # Upper boundary: H = (k+1)*2^{k-1} * |G|\n",
- " upper_boundary_coeff = (k + 1) * (2 ** (k - 1)) * reduction_threshold\n",
- " y_step_upper = [\n",
- " min(\n",
- " len(hidden_dims) - 1,\n",
- " (\n",
- " # Find the first H that satisfies H >= upper_boundary_coeff * p\n",
- " np.argmax(np.array(hidden_dims) >= upper_boundary_coeff * p)\n",
- " if upper_boundary_coeff * p <= max(hidden_dims)\n",
- " else len(hidden_dims) - 1\n",
- " ),\n",
- " )\n",
- " for p in p_values_filtered\n",
- " ]\n",
- " y_step_upper.append(y_step_upper[-1]) # Extend for step plot\n",
- " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n",
- " y_step_upper = [y - 0.5 for y in y_step_upper]\n",
- "\n",
- " # Lower boundary: H = 2^{k-1} * |G|\n",
- " lower_boundary_coeff = 2 ** (k - 1) * reduction_threshold\n",
- " y_step_lower = [\n",
- " min(\n",
- " len(hidden_dims) - 1,\n",
- " (\n",
- " # Find the first H that satisfies H >= lower_boundary_coeff * p\n",
- " np.argmax(np.array(hidden_dims) >= lower_boundary_coeff * p)\n",
- " if lower_boundary_coeff * p <= max(hidden_dims)\n",
- " else len(hidden_dims) - 1\n",
- " ),\n",
- " )\n",
- " for p in p_values_filtered\n",
- " ]\n",
- " y_step_lower.append(y_step_lower[-1]) # Extend for step plot\n",
- " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n",
- " y_step_lower = [y - 0.5 for y in y_step_lower]\n",
- "\n",
- " plt.step(\n",
- " x_step,\n",
- " y_step_upper,\n",
- " where=\"post\",\n",
- " color=\"red\",\n",
- " linewidth=4,\n",
- " linestyle=\"-\",\n",
- " label=f\"Upper boundary ($H$ = $(k+1) \\\\cdot 2^{{k-1}} |G| x {reduction_threshold}$ = {upper_boundary_coeff * reduction_threshold} * |G|) \",\n",
- " )\n",
- "\n",
- " plt.step(\n",
- " x_step,\n",
- " y_step_lower,\n",
- " where=\"post\",\n",
- " color=\"white\",\n",
- " linewidth=4,\n",
- " linestyle=\"-\",\n",
- " label=f\"Lower boundary ($H$ = $2^{{k-1}} |G| x {reduction_threshold}$ = {lower_boundary_coeff * reduction_threshold} * |G|) \",\n",
- " )\n",
- "\n",
- " # Place legend outside the plot area\n",
- " plt.legend(\n",
- " loc=\"upper center\", bbox_to_anchor=(0.5, -0.12), fontsize=12, frameon=True\n",
- " )\n",
- "\n",
- " plt.colorbar(label=f\"Steps to {reduction_threshold*100}% Convergence\")\n",
- " plt.title(\n",
- " f\"Steps to {reduction_threshold*100}% Convergence: Group Size $|G|$ vs Hidden Dimension $H$\\n($k={k}$, black = did not converge, p ≤ {max_p})\",\n",
- " fontsize=14,\n",
- " fontweight=\"bold\",\n",
- " )\n",
- " plt.tight_layout()\n",
- " plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "28e479df",
- "metadata": {},
- "source": [
- "### Final Training Loss\n",
- " "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "06fb5f82",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Load final training loss data for each k value separately\n",
- "max_p = 70 # Only visualize completed experiments (p <= 55)\n",
- "\n",
- "from matplotlib.colors import LogNorm\n",
- "\n",
- "# Create separate plots for each k value\n",
- "for k in k_values:\n",
- " loss_grid, loss_std = load_sweep_results_grid_final_loss_p_h(\n",
- " sweep_dir,\n",
- " k,\n",
- " p_values,\n",
- " hidden_dims,\n",
- " max_p=max_p,\n",
- " )\n",
- "\n",
- " # Filter p values - only show p <= max_p\n",
- " p_values_filtered = [p for p in p_values if p <= max_p]\n",
- "\n",
- " # Plot final loss heatmap: p (group size) vs hidden_dim\n",
- " plt.figure(figsize=(12, 8))\n",
- " cmap = plt.cm.viridis_r.copy()\n",
- " cmap.set_bad(color=\"black\")\n",
- " # Set extent to align cells with tick positions\n",
- " # extent: [left, right, bottom, top] in data coordinates\n",
- " plt.imshow(\n",
- " loss_grid[:, : len(p_values_filtered)],\n",
- " aspect=\"equal\",\n",
- " cmap=cmap,\n",
- " norm=LogNorm(),\n",
- " )\n",
- "\n",
- " plt.xlabel(\"Group Size $|G|$\", fontsize=14)\n",
- " plt.ylabel(\"Hidden Dimension $H$\", fontsize=14)\n",
- " plt.xticks(\n",
- " range(len(p_values_filtered)), p_values_filtered, rotation=45, ha=\"center\"\n",
- " )\n",
- "\n",
- " # Set y-axis ticks (hidden dimensions)\n",
- " plt.yticks(range(len(hidden_dims)), hidden_dims)\n",
- " plt.gca().invert_yaxis()\n",
- "\n",
- " # Theory boundaries\n",
- " x_step = np.arange(len(p_values_filtered) + 1) - 0.5\n",
- "\n",
- " # Upper boundary: H = (k+1)*2^{k-1} * |G|\n",
- " upper_boundary_coeff = (k + 1) * (2 ** (k - 1))\n",
- " y_step_upper = [\n",
- " min(\n",
- " len(hidden_dims) - 1,\n",
- " (\n",
- " # Find the first H that satisfies H >= upper_boundary_coeff * p\n",
- " np.argmax(np.array(hidden_dims) >= upper_boundary_coeff * p)\n",
- " if upper_boundary_coeff * p <= max(hidden_dims)\n",
- " else len(hidden_dims) - 1\n",
- " ),\n",
- " )\n",
- " for p in p_values_filtered\n",
- " ]\n",
- " y_step_upper.append(y_step_upper[-1]) # Extend for step plot\n",
- " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n",
- " y_step_upper = [y - 0.5 for y in y_step_upper]\n",
- "\n",
- " # Lower boundary: H = 2^{k-1} * |G|\n",
- " lower_boundary_coeff = 2 ** (k - 1)\n",
- " y_step_lower = [\n",
- " min(\n",
- " len(hidden_dims) - 1,\n",
- " (\n",
- " # Find the first H that satisfies H >= lower_boundary_coeff * p\n",
- " np.argmax(np.array(hidden_dims) >= lower_boundary_coeff * p)\n",
- " if lower_boundary_coeff * p <= max(hidden_dims)\n",
- " else len(hidden_dims) - 1\n",
- " ),\n",
- " )\n",
- " for p in p_values_filtered\n",
- " ]\n",
- " y_step_lower.append(y_step_lower[-1]) # Extend for step plot\n",
- " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n",
- " y_step_lower = [y - 0.5 for y in y_step_lower]\n",
- "\n",
- " plt.step(\n",
- " x_step,\n",
- " y_step_upper,\n",
- " where=\"post\",\n",
- " color=\"magenta\",\n",
- " linewidth=3,\n",
- " linestyle=\"--\",\n",
- " label=f\"Upper boundary ($H$ ≥ $(k+1) \\\\cdot 2^{{k-1}} |G|$)\",\n",
- " )\n",
- "\n",
- " plt.step(\n",
- " x_step,\n",
- " y_step_lower,\n",
- " where=\"post\",\n",
- " color=\"red\",\n",
- " linewidth=3,\n",
- " linestyle=\"--\",\n",
- " label=f\"Lower boundary ($H$ ≥ $2^{{k-1}} |G|$)\",\n",
- " )\n",
- "\n",
- " # Place legend outside the plot area\n",
- " plt.legend(\n",
- " loc=\"upper center\", bbox_to_anchor=(0.5, -0.12), fontsize=12, frameon=True\n",
- " )\n",
- "\n",
- " plt.colorbar(label=\"Final Training Loss\")\n",
- " plt.title(\n",
- " f\"Final Training Loss: Group Size $|G|$ vs Hidden Dimension $H$\\n($k={k}$)\",\n",
- " fontsize=14,\n",
- " fontweight=\"bold\",\n",
- " )\n",
- " plt.tight_layout()\n",
- " plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "93d367cb",
- "metadata": {},
- "source": [
- "### Training Loss Curves by Group Size\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5ed8f9c0",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Plot training loss curves for different group sizes\n",
- "# Specify the hidden dimension to use\n",
- "hidden_dim = 160 # Change this to plot different hidden dimensions\n",
- "\n",
- "# Use all available p values (or filter as needed)\n",
- "p_values_to_plot = [p for p in p_values if p <= 55] # Adjust max_p as needed\n",
- "\n",
- "# Create separate plots for each k value\n",
- "for k in k_values:\n",
- " # Load training loss curves for different p values\n",
- " curves = load_training_loss_curves_p(\n",
- " sweep_dir,\n",
- " k,\n",
- " hidden_dim,\n",
- " p_values_to_plot,\n",
- " )\n",
- "\n",
- " if not curves:\n",
- " print(f\"No data found for k={k}, H={hidden_dim}\")\n",
- " continue\n",
- "\n",
- " # Create plot\n",
- " plt.figure(figsize=(10, 8))\n",
- "\n",
- " # Plot each group size as a separate curve\n",
- " # Use a colormap to distinguish different p values\n",
- " colors = plt.cm.viridis(np.linspace(0, 1, len(curves)))\n",
- "\n",
- " for i, (p, loss_histories) in enumerate(sorted(curves.items())):\n",
- " # Plot mean curve with shaded error bars\n",
- " # Find the maximum length to align all curves\n",
- " max_len = max(len(hist) for hist in loss_histories)\n",
- "\n",
- " # Pad shorter histories with NaN or last value\n",
- " aligned_histories = []\n",
- " for hist in loss_histories:\n",
- " if len(hist) < max_len:\n",
- " padded = np.full(max_len, np.nan)\n",
- " padded[: len(hist)] = hist\n",
- " aligned_histories.append(padded)\n",
- " else:\n",
- " aligned_histories.append(hist)\n",
- "\n",
- " aligned_histories = np.array(aligned_histories)\n",
- "\n",
- " # Compute mean and std across seeds\n",
- " mean_loss = np.nanmean(aligned_histories, axis=0)\n",
- " std_loss = np.nanstd(aligned_histories, axis=0)\n",
- "\n",
- " # Create step array (1-indexed for log scale)\n",
- " steps = np.arange(1, len(mean_loss) + 1)\n",
- "\n",
- " # Plot mean curve\n",
- " plt.loglog(\n",
- " steps,\n",
- " mean_loss,\n",
- " color=colors[i],\n",
- " linewidth=2,\n",
- " label=f\"$|G|={p}$\",\n",
- " )\n",
- "\n",
- " # Plot shaded error region (optional, can be commented out if too cluttered)\n",
- " # plt.fill_between(\n",
- " # steps,\n",
- " # mean_loss - std_loss,\n",
- " # mean_loss + std_loss,\n",
- " # color=colors[i],\n",
- " # alpha=0.2,\n",
- " # )\n",
- "\n",
- " plt.xlabel(\"Training Steps\", fontsize=14)\n",
- " plt.ylabel(\"Training Loss\", fontsize=14)\n",
- " plt.title(\n",
- " f\"Training Loss Curves: Group Size $|G|$ vs Steps\\n($k={k}$, $H={hidden_dim}$)\",\n",
- " fontsize=14,\n",
- " fontweight=\"bold\",\n",
- " )\n",
- " plt.legend(loc=\"best\", fontsize=10, ncol=2)\n",
- " plt.grid(True, alpha=0.3, which=\"both\")\n",
- " plt.tight_layout()\n",
- " plt.show()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "96be1620",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Load final validation loss data for each k value separately\n",
- "max_p = 60 # Only visualize completed experiments (p <= 55)\n",
- "\n",
- "from matplotlib.colors import LogNorm\n",
- "\n",
- "# Create separate plots for each k value\n",
- "for k in k_values:\n",
- " loss_grid, loss_std = load_sweep_results_grid_final_val_loss_p_h(\n",
- " sweep_dir,\n",
- " k,\n",
- " p_values,\n",
- " hidden_dims,\n",
- " max_p=max_p,\n",
- " )\n",
- "\n",
- " # Filter p values - only show p <= max_p\n",
- " p_values_filtered = [p for p in p_values if p <= max_p]\n",
- "\n",
- " # Plot final validation loss heatmap: p (group size) vs hidden_dim\n",
- " plt.figure(figsize=(12, 8))\n",
- " cmap = plt.cm.viridis_r.copy()\n",
- " cmap.set_bad(color=\"black\")\n",
- " # Set extent to align cells with tick positions\n",
- " # extent: [left, right, bottom, top] in data coordinates\n",
- " plt.imshow(\n",
- " loss_grid[:, : len(p_values_filtered)],\n",
- " aspect=\"equal\",\n",
- " cmap=cmap,\n",
- " norm=LogNorm(),\n",
- " )\n",
- "\n",
- " plt.xlabel(\"Group Size $|G|$\", fontsize=14)\n",
- " plt.ylabel(\"Hidden Dimension $H$\", fontsize=14)\n",
- " plt.xticks(\n",
- " range(len(p_values_filtered)), p_values_filtered, rotation=45, ha=\"center\"\n",
- " )\n",
- "\n",
- " # Set y-axis ticks (hidden dimensions)\n",
- " plt.yticks(range(len(hidden_dims)), hidden_dims)\n",
- " plt.gca().invert_yaxis()\n",
- "\n",
- " # Theory boundaries\n",
- " x_step = np.arange(len(p_values_filtered) + 1) - 0.5\n",
- "\n",
- " # Upper boundary: H = (k+1)*2^{k-1} * |G|\n",
- " upper_boundary_coeff = (k + 1) * (2 ** (k - 1))\n",
- " y_step_upper = [\n",
- " min(\n",
- " len(hidden_dims) - 1,\n",
- " (\n",
- " # Find the first H that satisfies H >= upper_boundary_coeff * p\n",
- " np.argmax(np.array(hidden_dims) >= upper_boundary_coeff * p)\n",
- " if upper_boundary_coeff * p <= max(hidden_dims)\n",
- " else len(hidden_dims) - 1\n",
- " ),\n",
- " )\n",
- " for p in p_values_filtered\n",
- " ]\n",
- " y_step_upper.append(y_step_upper[-1]) # Extend for step plot\n",
- " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n",
- " y_step_upper = [y - 0.5 for y in y_step_upper]\n",
- "\n",
- " # Lower boundary: H = 2^{k-1} * |G|\n",
- " lower_boundary_coeff = 2 ** (k - 1)\n",
- " y_step_lower = [\n",
- " min(\n",
- " len(hidden_dims) - 1,\n",
- " (\n",
- " # Find the first H that satisfies H >= lower_boundary_coeff * p\n",
- " np.argmax(np.array(hidden_dims) >= lower_boundary_coeff * p)\n",
- " if lower_boundary_coeff * p <= max(hidden_dims)\n",
- " else len(hidden_dims) - 1\n",
- " ),\n",
- " )\n",
- " for p in p_values_filtered\n",
- " ]\n",
- " y_step_lower.append(y_step_lower[-1]) # Extend for step plot\n",
- " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n",
- " y_step_lower = [y - 0.5 for y in y_step_lower]\n",
- "\n",
- " plt.step(\n",
- " x_step,\n",
- " y_step_upper,\n",
- " where=\"post\",\n",
- " color=\"red\",\n",
- " linewidth=3,\n",
- " linestyle=\"--\",\n",
- " label=f\"Upper boundary ($H$ ≥ $(k+1) \\\\cdot 2^{{k-1}} |G|$)\",\n",
- " )\n",
- "\n",
- " plt.step(\n",
- " x_step,\n",
- " y_step_lower,\n",
- " where=\"post\",\n",
- " color=\"blue\",\n",
- " linewidth=3,\n",
- " linestyle=\"--\",\n",
- " label=f\"Lower boundary ($H$ ≥ $2^{{k-1}} |G|$)\",\n",
- " )\n",
- "\n",
- " # Place legend outside the plot area\n",
- " plt.legend(\n",
- " loc=\"upper center\", bbox_to_anchor=(0.5, -0.12), fontsize=12, frameon=True\n",
- " )\n",
- "\n",
- " plt.colorbar(label=\"Final Validation Loss\")\n",
- " plt.title(\n",
- " f\"Final Validation Loss: Group Size $|G|$ vs Hidden Dimension $H$\\n($k={k}$, black = incomplete experiment, p ≤ {max_p})\",\n",
- " fontsize=14,\n",
- " fontweight=\"bold\",\n",
- " )\n",
- " plt.tight_layout()\n",
- " plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "d3111eeb",
- "metadata": {},
- "source": [
- "### Training Instability"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d743a392",
- "metadata": {},
- "outputs": [],
- "source": [
- "def load_sweep_results_grid_spikiness_p_h(\n",
- " sweep_dir: str, k: int, p_values: list, hidden_dims: list, max_p: int = None\n",
- "):\n",
- " \"\"\"\n",
- " Compute fraction of training steps where loss increased (instability) for p vs h sweeps.\n",
- "\n",
- " Updated for experiment naming: k{k}_p{p}_h{h}\n",
- " Only loads completed experiments (checks for run_summary.yaml).\n",
- "\n",
- " Args:\n",
- " sweep_dir: Path to the sweep directory\n",
- " k: Sequence length parameter (2, 3, or 4)\n",
- " p_values: List of p (group size) values\n",
- " hidden_dims: List of hidden dimension values\n",
- " max_p: Maximum p value to include (filters incomplete experiments)\n",
- "\n",
- " Returns:\n",
- " grid: 2D array with mean frac_upward across seeds\n",
- " Shape: (len(hidden_dims), len(p_values))\n",
- " std_grid: 2D array with standard deviations\n",
- " \"\"\"\n",
- " sweep_path = Path(sweep_dir)\n",
- "\n",
- " grid = np.full((len(hidden_dims), len(p_values)), np.nan)\n",
- " std_grid = np.full((len(hidden_dims), len(p_values)), np.nan)\n",
- "\n",
- " for i, h in enumerate(hidden_dims):\n",
- " for j, p in enumerate(p_values):\n",
- " # Filter by max_p if specified\n",
- " if max_p is not None and p > max_p:\n",
- " continue\n",
- "\n",
- " exp_name = f\"k{k}_p{p}_h{h}\"\n",
- " exp_dir = sweep_path / exp_name\n",
- "\n",
- " if not exp_dir.exists():\n",
- " continue\n",
- "\n",
- " # Check if experiment is completed\n",
- " seed_dir = exp_dir / \"seed_0\"\n",
- " if not seed_dir.exists() or not (seed_dir / \"run_summary.yaml\").exists():\n",
- " continue # Skip incomplete experiments\n",
- "\n",
- " frac_upwards = []\n",
- " for seed_dir in exp_dir.glob(\"seed_*\"):\n",
- " loss_file = seed_dir / \"train_loss_history.npy\"\n",
- " if loss_file.exists():\n",
- " loss_history = np.load(loss_file)\n",
- " log_loss = np.log10(loss_history + 1e-10)\n",
- " log_changes = np.diff(log_loss)\n",
- "\n",
- " # Fraction of steps where loss went UP\n",
- " frac_upward = np.sum(log_changes > 0) / len(log_changes)\n",
- " frac_upwards.append(frac_upward)\n",
- "\n",
- " if frac_upwards:\n",
- " grid[i, j] = np.mean(frac_upwards)\n",
- " std_grid[i, j] = np.std(frac_upwards) if len(frac_upwards) > 1 else 0.0\n",
- "\n",
- " return grid, std_grid"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "683b555c",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Load spikiness data for each k value separately\n",
- "max_p = 70 # Only visualize completed experiments\n",
- "\n",
- "# Create separate plots for each k value\n",
- "for k in k_values:\n",
- " spike_grid_p, spike_std_p = load_sweep_results_grid_spikiness_p_h(\n",
- " sweep_dir, k, p_values, hidden_dims, max_p=max_p\n",
- " )\n",
- "\n",
- " p_values_filtered = [p for p in p_values if p <= max_p]\n",
- "\n",
- " # Plot\n",
- " plt.figure(figsize=(12, 8))\n",
- " # Set extent to align cells with tick positions\n",
- " plt.imshow(\n",
- " spike_grid_p[:, : len(p_values_filtered)],\n",
- " aspect=\"equal\",\n",
- " cmap=\"plasma\",\n",
- " extent=[-0.5, len(p_values_filtered) - 0.5, len(hidden_dims) - 0.5, -0.5],\n",
- " )\n",
- " plt.xlabel(\"Group Size $|G|$\", fontsize=14)\n",
- " plt.ylabel(\"Hidden Dimension $H$\", fontsize=14)\n",
- " plt.xticks(\n",
- " range(len(p_values_filtered)), p_values_filtered, rotation=45, ha=\"center\"\n",
- " )\n",
- " plt.yticks(range(len(hidden_dims)), hidden_dims)\n",
- " plt.gca().invert_yaxis()\n",
- "\n",
- " # Theory boundaries\n",
- " x_step = np.arange(len(p_values_filtered) + 1) - 0.5\n",
- "\n",
- " # Upper boundary: H = (k+1)*2^{k-1} * |G|\n",
- " upper_boundary_coeff = (k + 1) * (2 ** (k - 1))\n",
- " y_step_upper = [\n",
- " min(\n",
- " len(hidden_dims) - 1,\n",
- " (\n",
- " np.argmax(np.array(hidden_dims) >= upper_boundary_coeff * p)\n",
- " if upper_boundary_coeff * p <= max(hidden_dims)\n",
- " else len(hidden_dims) - 1\n",
- " ),\n",
- " )\n",
- " for p in p_values_filtered\n",
- " ]\n",
- " y_step_upper.append(y_step_upper[-1])\n",
- " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n",
- " y_step_upper = [y - 0.5 for y in y_step_upper]\n",
- "\n",
- " # Lower boundary: H = 2^{k-1} * |G|\n",
- " lower_boundary_coeff = 2 ** (k - 1)\n",
- " y_step_lower = [\n",
- " min(\n",
- " len(hidden_dims) - 1,\n",
- " (\n",
- " np.argmax(np.array(hidden_dims) >= lower_boundary_coeff * p)\n",
- " if lower_boundary_coeff * p <= max(hidden_dims)\n",
- " else len(hidden_dims) - 1\n",
- " ),\n",
- " )\n",
- " for p in p_values_filtered\n",
- " ]\n",
- " y_step_lower.append(y_step_lower[-1])\n",
- " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n",
- " y_step_lower = [y - 0.5 for y in y_step_lower]\n",
- "\n",
- " plt.step(\n",
- " x_step,\n",
- " y_step_upper,\n",
- " where=\"post\",\n",
- " color=\"red\",\n",
- " linewidth=3,\n",
- " linestyle=\"--\",\n",
- " label=f\"Upper boundary ($H$ = $(k+1) \\\\cdot 2^{{k-1}} |G|$)\",\n",
- " )\n",
- "\n",
- " plt.step(\n",
- " x_step,\n",
- " y_step_lower,\n",
- " where=\"post\",\n",
- " color=\"white\",\n",
- " linewidth=3,\n",
- " linestyle=\"--\",\n",
- " label=f\"Lower boundary ($H$ = $2^{{k-1}} |G|$)\",\n",
- " )\n",
- "\n",
- " plt.legend(\n",
- " loc=\"upper center\", bbox_to_anchor=(0.5, -0.12), fontsize=12, frameon=True\n",
- " )\n",
- "\n",
- " plt.colorbar(label=\"Fraction of Upward Steps (Spikiness)\")\n",
- " plt.title(\n",
- " f\"Training Instability: Group Size $|G|$ vs Hidden Dimension $H$\\n($k={k}$)\",\n",
- " fontsize=14,\n",
- " fontweight=\"bold\",\n",
- " )\n",
- " plt.tight_layout()\n",
- " plt.show()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "772517ad",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Load both metrics for each k value separately\n",
- "reduction_threshold = 0.99\n",
- "spikiness_threshold = 0.1\n",
- "max_p = 55 # Only visualize completed experiments\n",
- "\n",
- "# Create separate plots for each k value\n",
- "for k in k_values:\n",
- " conv_grid_p, conv_std_p = load_sweep_results_grid_convergence_p_h(\n",
- " sweep_dir, k, p_values, hidden_dims, \n",
- " reduction_threshold=reduction_threshold,\n",
- " max_p=max_p\n",
- " )\n",
- " spike_grid_p, spike_std_p = load_sweep_results_grid_spikiness_p_h(\n",
- " sweep_dir, k, p_values, hidden_dims, max_p=max_p\n",
- " )\n",
- " \n",
- " p_values_filtered = [p for p in p_values if p <= max_p]\n",
- "\n",
- " # Create categorical grid: 0=black (no conv), 1=purple (spiky), 2=yellow (smooth)\n",
- " category_grid = np.full((len(hidden_dims), len(p_values_filtered)), 0.0) # Start with 0 (black)\n",
- "\n",
- " for i in range(len(hidden_dims)):\n",
- " for j in range(len(p_values_filtered)):\n",
- " converged = not np.isnan(conv_grid_p[i, j])\n",
- "\n",
- " if converged:\n",
- " spiky = spike_grid_p[i, j] > spikiness_threshold\n",
- " if spiky:\n",
- " category_grid[i, j] = 1.0 # Purple (spiky)\n",
- " else:\n",
- " category_grid[i, j] = 2.0 # Yellow (smooth)\n",
- " # else stays 0.0 (black, did not converge)\n",
- "\n",
- " # Plot\n",
- " fig, ax = plt.subplots(figsize=(12, 8))\n",
- "\n",
- " # Custom colormap: black -> purple -> yellow\n",
- " from matplotlib.colors import ListedColormap\n",
- "\n",
- " colors = [\"black\", \"purple\", \"yellow\"]\n",
- " cmap = ListedColormap(colors)\n",
- "\n",
- " # Set extent to align cells with tick positions\n",
- " im = ax.imshow(\n",
- " category_grid, \n",
- " aspect=\"auto\", \n",
- " cmap=cmap, \n",
- " vmin=0, \n",
- " vmax=2,\n",
- " extent=[-0.5, len(p_values_filtered) - 0.5, len(hidden_dims) - 0.5, -0.5]\n",
- " )\n",
- "\n",
- " ax.set_xlabel(\"Group Size $|G|$\", fontsize=14)\n",
- " ax.set_ylabel(\"Hidden Dimension $H$\", fontsize=14)\n",
- "\n",
- " # Set x-axis ticks (p values)\n",
- " ax.set_xticks(range(len(p_values_filtered)))\n",
- " ax.set_xticklabels(p_values_filtered, rotation=45, ha=\"center\")\n",
- "\n",
- " # Set y-axis ticks (hidden dimensions)\n",
- " ax.set_yticks(range(len(hidden_dims)))\n",
- " ax.set_yticklabels(hidden_dims)\n",
- " ax.invert_yaxis()\n",
- "\n",
- " # Theory boundaries\n",
- " x_step = np.arange(len(p_values_filtered) + 1) - 0.5\n",
- " \n",
- " # Upper boundary: H = (k+1)*2^{k-1} * |G|\n",
- " upper_boundary_coeff = (k + 1) * (2 ** (k - 1))\n",
- " y_step_upper = [\n",
- " min(\n",
- " len(hidden_dims) - 1,\n",
- " (\n",
- " np.argmax(np.array(hidden_dims) >= upper_boundary_coeff * p)\n",
- " if upper_boundary_coeff * p <= max(hidden_dims)\n",
- " else len(hidden_dims) - 1\n",
- " ),\n",
- " )\n",
- " for p in p_values_filtered\n",
- " ]\n",
- " y_step_upper.append(y_step_upper[-1])\n",
- " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n",
- " y_step_upper = [y - 0.5 for y in y_step_upper]\n",
- "\n",
- " # Lower boundary: H = 2^{k-1} * |G|\n",
- " lower_boundary_coeff = 2 ** (k - 1)\n",
- " y_step_lower = [\n",
- " min(\n",
- " len(hidden_dims) - 1,\n",
- " (\n",
- " np.argmax(np.array(hidden_dims) >= lower_boundary_coeff * p)\n",
- " if lower_boundary_coeff * p <= max(hidden_dims)\n",
- " else len(hidden_dims) - 1\n",
- " ),\n",
- " )\n",
- " for p in p_values_filtered\n",
- " ]\n",
- " y_step_lower.append(y_step_lower[-1])\n",
- " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n",
- " y_step_lower = [y - 0.5 for y in y_step_lower]\n",
- "\n",
- " ax.step(\n",
- " x_step,\n",
- " y_step_upper,\n",
- " where=\"post\",\n",
- " color=\"red\",\n",
- " linewidth=3,\n",
- " linestyle=\"--\",\n",
- " label=f\"Upper boundary ($H$ ≥ $(k+1) \\\\cdot 2^{{k-1}} |G|$)\",\n",
- " )\n",
- " \n",
- " ax.step(\n",
- " x_step,\n",
- " y_step_lower,\n",
- " where=\"post\",\n",
- " color=\"blue\",\n",
- " linewidth=3,\n",
- " linestyle=\"--\",\n",
- " label=f\"Lower boundary ($H$ ≥ $2^{{k-1}} |G|$)\",\n",
- " )\n",
- "\n",
- " # Create custom legend\n",
- " from matplotlib.patches import Patch\n",
- "\n",
- " legend_elements = [\n",
- " Patch(facecolor=\"black\", label=\"Did not converge\"),\n",
- " Patch(facecolor=\"purple\", label=f\"Spiky (frac_up > {spikiness_threshold})\"),\n",
- " Patch(facecolor=\"yellow\", label=\"Smooth convergence\"),\n",
- " plt.Line2D([0], [0], color=\"r\", linewidth=3, linestyle=\"--\", label=f\"Upper boundary ($H$ ≥ $(k+1) \\\\cdot 2^{{k-1}} |G|$)\"),\n",
- " plt.Line2D([0], [0], color=\"b\", linewidth=3, linestyle=\"--\", label=f\"Lower boundary ($H$ ≥ $2^{{k-1}} |G|$)\"),\n",
- " ]\n",
- "\n",
- " ax.legend(handles=legend_elements, loc=\"upper left\", fontsize=11, frameon=True)\n",
- "\n",
- " ax.set_title(\n",
- " f\"Convergence & Spikiness: $|G|$ vs $H$ ($k={k}$)\\nThresholds: {reduction_threshold*100}% convergence, {spikiness_threshold} spikiness (p ≤ {max_p})\",\n",
- " fontsize=14,\n",
- " fontweight=\"bold\",\n",
- " )\n",
- " plt.tight_layout()\n",
- " plt.show()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "53ed6f4f",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Load both convergence and spikiness data for each k value separately\n",
- "reduction_threshold = 0.99\n",
- "max_p = 55 # Only visualize completed experiments\n",
- "\n",
- "# Create separate plots for each k value\n",
- "for k in k_values:\n",
- " conv_grid_p, conv_std_p = load_sweep_results_grid_convergence_p_h(\n",
- " sweep_dir,\n",
- " k,\n",
- " p_values,\n",
- " hidden_dims,\n",
- " reduction_threshold=reduction_threshold,\n",
- " max_p=max_p,\n",
- " )\n",
- " spike_grid_p, spike_std_p = load_sweep_results_grid_spikiness_p_h(\n",
- " sweep_dir, k, p_values, hidden_dims, max_p=max_p\n",
- " )\n",
- "\n",
- " p_values_filtered = [p for p in p_values if p <= max_p]\n",
- "\n",
- " # Mask spikiness grid: only show spikiness for converged runs\n",
- " spike_grid_masked = spike_grid_p.copy()\n",
- " for i in range(len(hidden_dims)):\n",
- " for j in range(len(p_values_filtered)):\n",
- " if np.isnan(conv_grid_p[i, j]):\n",
- " # Did not converge - set to NaN (will be black)\n",
- " spike_grid_masked[i, j] = np.nan\n",
- "\n",
- " # Plot with masked spikiness\n",
- " plt.figure(figsize=(12, 8))\n",
- "\n",
- " # Use colormap with black for NaN\n",
- " cmap_spike = plt.cm.plasma.copy()\n",
- " cmap_spike.set_bad(color=\"black\")\n",
- "\n",
- " # Set extent to align cells with tick positions\n",
- " plt.imshow(\n",
- " spike_grid_masked[:, : len(p_values_filtered)],\n",
- " aspect=\"auto\",\n",
- " cmap=cmap_spike,\n",
- " vmin=0,\n",
- " vmax=0.5,\n",
- " extent=[-0.5, len(p_values_filtered) - 0.5, len(hidden_dims) - 0.5, -0.5],\n",
- " )\n",
- " plt.xlabel(\"Group Size $|G|$\", fontsize=14)\n",
- " plt.ylabel(\"Hidden Dimension $H$\", fontsize=14)\n",
- " plt.xticks(\n",
- " range(len(p_values_filtered)), p_values_filtered, rotation=45, ha=\"center\"\n",
- " )\n",
- " plt.yticks(range(len(hidden_dims)), hidden_dims)\n",
- " plt.gca().invert_yaxis()\n",
- "\n",
- " # Theory boundaries\n",
- " x_step = np.arange(len(p_values_filtered) + 1) - 0.5\n",
- "\n",
- " # Upper boundary: H = (k+1)*2^{k-1} * |G|\n",
- " upper_boundary_coeff = (k + 1) * (2 ** (k - 1))\n",
- " y_step_upper = [\n",
- " min(\n",
- " len(hidden_dims) - 1,\n",
- " (\n",
- " np.argmax(np.array(hidden_dims) >= upper_boundary_coeff * p)\n",
- " if upper_boundary_coeff * p <= max(hidden_dims)\n",
- " else len(hidden_dims) - 1\n",
- " ),\n",
- " )\n",
- " for p in p_values_filtered\n",
- " ]\n",
- " y_step_upper.append(y_step_upper[-1])\n",
- " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n",
- " y_step_upper = [y - 0.5 for y in y_step_upper]\n",
- "\n",
- " # Lower boundary: H = 2^{k-1} * |G|\n",
- " lower_boundary_coeff = 2 ** (k - 1)\n",
- " y_step_lower = [\n",
- " min(\n",
- " len(hidden_dims) - 1,\n",
- " (\n",
- " np.argmax(np.array(hidden_dims) >= lower_boundary_coeff * p)\n",
- " if lower_boundary_coeff * p <= max(hidden_dims)\n",
- " else len(hidden_dims) - 1\n",
- " ),\n",
- " )\n",
- " for p in p_values_filtered\n",
- " ]\n",
- " y_step_lower.append(y_step_lower[-1])\n",
- " # Convert to edge positions (subtract 0.5 to place at bottom edge of cells)\n",
- " y_step_lower = [y - 0.5 for y in y_step_lower]\n",
- "\n",
- " plt.step(\n",
- " x_step,\n",
- " y_step_upper,\n",
- " where=\"post\",\n",
- " color=\"red\",\n",
- " linewidth=3,\n",
- " linestyle=\"--\",\n",
- " label=f\"Upper boundary ($H$ ≥ $(k+1) \\\\cdot 2^{{k-1}} |G|$)\",\n",
- " )\n",
- "\n",
- " plt.step(\n",
- " x_step,\n",
- " y_step_lower,\n",
- " where=\"post\",\n",
- " color=\"blue\",\n",
- " linewidth=3,\n",
- " linestyle=\"--\",\n",
- " label=f\"Lower boundary ($H$ ≥ $2^{{k-1}} |G|$)\",\n",
- " )\n",
- "\n",
- " # Custom legend\n",
- " from matplotlib.patches import Patch\n",
- "\n",
- " legend_elements = [\n",
- " Patch(facecolor=\"black\", label=\"Did not converge\"),\n",
- " Patch(facecolor=\"purple\", label=\"Low spikiness (~0)\"),\n",
- " Patch(facecolor=\"yellow\", label=\"High spikiness (~0.5)\"),\n",
- " plt.Line2D(\n",
- " [0],\n",
- " [0],\n",
- " color=\"red\",\n",
- " linewidth=3,\n",
- " linestyle=\"--\",\n",
- " label=f\"Upper boundary ($H$ ≥ $(k+1) \\\\cdot 2^{{k-1}} |G|$)\",\n",
- " ),\n",
- " plt.Line2D(\n",
- " [0],\n",
- " [0],\n",
- " color=\"white\",\n",
- " linewidth=3,\n",
- " linestyle=\"--\",\n",
- " label=f\"Lower boundary ($H$ ≥ $2^{{k-1}} |G|$)\",\n",
- " ),\n",
- " ]\n",
- " plt.legend(\n",
- " handles=legend_elements,\n",
- " loc=\"upper center\",\n",
- " bbox_to_anchor=(0.5, -0.12),\n",
- " fontsize=11,\n",
- " frameon=True,\n",
- " ncol=5,\n",
- " )\n",
- "\n",
- " plt.colorbar(label=\"Fraction of Upward Steps (Spikiness)\")\n",
- " plt.title(\n",
- " f\"Training Instability: Group Size $|G|$ vs Hidden Dimension $H$\\n($k={k}$, black = did not converge, p ≤ {max_p})\",\n",
- " fontsize=14,\n",
- " fontweight=\"bold\",\n",
- " )\n",
- " plt.tight_layout()\n",
- " plt.show()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "3f30eb24",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "group-agf",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.12"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/notebooks/sequential_cnxcn.ipynb b/notebooks/sequential_cnxcn.ipynb
new file mode 100644
index 0000000..f957123
--- /dev/null
+++ b/notebooks/sequential_cnxcn.ipynb
@@ -0,0 +1,287 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Sequential Group Composition on $C_n \\times C_n$\n",
+ "\n",
+ "**Group:** Product of cyclic groups $C_n \\times C_n$ of order $n^2$. \n",
+ "**Task:** Given a sequence of $k$ group elements $g_1, \\ldots, g_k \\in C_n \\times C_n$, predict their cumulative product. \n",
+ "**Sequence length:** $k = 3$ (sequential composition). \n",
+ "**Architecture:** `QuadraticRNN` with quadratic recurrence. \n",
+ "**Key result:** The RNN composes elements sequentially in $k$ steps, exploiting associativity."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import random\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.utils.data import DataLoader, TensorDataset\n",
+ "\n",
+ "import src.dataset as dataset\n",
+ "import src.model as model\n",
+ "import src.optimizer as optimizer\n",
+ "import src.power as power\n",
+ "import src.template as template\n",
+ "import src.train as train_mod\n",
+ "import src.viz as viz"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Configuration"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n",
+ "\n",
+ "seed = 0\n",
+ "random.seed(seed)\n",
+ "np.random.seed(seed)\n",
+ "torch.manual_seed(seed)\n",
+ "\n",
+ "p1, p2 = 7, 7 # Cn x Cn dimensions\n",
+ "p_flat = p1 * p2\n",
+ "k = 3 # Sequence length\n",
+ "hidden_dim = 50 if TEST_MODE else 200\n",
+ "epochs = 2 if TEST_MODE else 5\n",
+ "num_samples = 100 if TEST_MODE else 10000\n",
+ "batch_size = 64 if TEST_MODE else 1000\n",
+ "lr = 1e-3\n",
+ "init_scale = 1e-2\n",
+ "mode = \"sampled\"\n",
+ "\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ "FIGURES_DIR = \"figures\"\n",
+ "os.makedirs(FIGURES_DIR, exist_ok=True)\n",
+ "\n",
+ "print(f\"Group: C_{p1} x C_{p2}, order {p_flat}\")\n",
+ "print(f\"Sequence length: k={k}\")\n",
+ "print(f\"Device: {device}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Template and Dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Build a 2D template with known Fourier structure\n",
+ "tpl_2d = template.unique_freqs_2d(p1, p2, n_freqs=5, seed=seed)\n",
+ "\n",
+ "# Build sequential dataset\n",
+ "X, Y, sequence_xy = dataset.build_modular_addition_sequence_dataset_2d(\n",
+ " p1, p2, tpl_2d, k, mode=mode, num_samples=num_samples,\n",
+ ")\n",
+ "\n",
+ "X_tensor = torch.tensor(X, dtype=torch.float32).to(device)\n",
+ "Y_tensor = torch.tensor(Y, dtype=torch.float32).to(device)\n",
+ "\n",
+ "ds = TensorDataset(X_tensor, Y_tensor)\n",
+ "dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True)\n",
+ "\n",
+ "print(f\"Dataset: {len(ds)} samples\")\n",
+ "print(f\"X shape: {X_tensor.shape} (N, k, p1*p2)\")\n",
+ "print(f\"Y shape: {Y_tensor.shape} (N, p1*p2)\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Visualize template\n",
+ "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))\n",
+ "\n",
+ "im = ax1.imshow(tpl_2d, cmap=\"RdBu_r\")\n",
+ "ax1.set_title(f\"Template on $C_{{{p1}}} \\\\times C_{{{p2}}}$\")\n",
+ "ax1.set_xlabel(\"$C_{\" + str(p2) + \"}$ index\")\n",
+ "ax1.set_ylabel(\"$C_{\" + str(p1) + \"}$ index\")\n",
+ "plt.colorbar(im, ax=ax1)\n",
+ "\n",
+ "# Show 2D power spectrum\n",
+ "pwr_2d = power.get_power_2d(tpl_2d, no_freq=True)\n",
+ "im2 = ax2.imshow(np.log10(pwr_2d + 1e-12), cmap=\"hot\")\n",
+ "ax2.set_title(\"Log power spectrum\")\n",
+ "ax2.set_xlabel(\"Freq $k_2$\")\n",
+ "ax2.set_ylabel(\"Freq $k_1$\")\n",
+ "plt.colorbar(im2, ax=ax2)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.savefig(f\"{FIGURES_DIR}/sequential_cnxcn_template.pdf\", bbox_inches=\"tight\")\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Model and Optimizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "net = model.QuadraticRNN(\n",
+ " p=p_flat,\n",
+ " d=hidden_dim,\n",
+ " template=tpl_2d.flatten(),\n",
+ " init_scale=init_scale,\n",
+ ")\n",
+ "net = net.to(device)\n",
+ "\n",
+ "criterion = nn.MSELoss()\n",
+ "opt = optimizer.HybridRNNOptimizer(net, lr=lr)\n",
+ "\n",
+ "print(f\"Model: QuadraticRNN(p={p_flat}, d={hidden_dim}, k={k})\")\n",
+ "print(f\"Optimizer: HybridRNNOptimizer(lr={lr})\")\n",
+ "print(f\"Training for {epochs} epochs\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "loss_history, val_loss_history, param_history, param_save_epochs, final_epoch = train_mod.train(\n",
+ " net,\n",
+ " dataloader,\n",
+ " criterion,\n",
+ " opt,\n",
+ " epochs=epochs,\n",
+ " verbose_interval=1,\n",
+ " save_param_interval=1,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Training Loss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig, ax = plt.subplots(figsize=(8, 6))\n",
+ "\n",
+ "# Plot training loss with theoretical levels\n",
+ "ax.plot(loss_history, lw=4, label=\"Train loss\")\n",
+ "\n",
+ "theory = power.theoretical_loss_levels_2d(tpl_2d)\n",
+ "for level in theory[\"levels\"]:\n",
+ " ax.axhline(y=level, color=\"black\", linestyle=\"--\", linewidth=1.5, zorder=-2)\n",
+ "\n",
+ "ax.set_xlabel(\"Epochs\", fontsize=18)\n",
+ "ax.set_ylabel(\"Train Loss\", fontsize=18)\n",
+ "ax.set_title(f\"Training loss (sequential $k={k}$ on $C_{{{p1}}} \\\\times C_{{{p2}}}$)\", fontsize=18)\n",
+ "viz.style_axes(ax)\n",
+ "ax.grid(False)\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.savefig(f\"{FIGURES_DIR}/sequential_cnxcn_loss.pdf\", bbox_inches=\"tight\")\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Predictions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Show predictions vs ground truth\n",
+ "net.load_state_dict(param_history[-1])\n",
+ "net.eval()\n",
+ "\n",
+ "n_examples = 3\n",
+ "indices = np.random.choice(len(Y_tensor), size=n_examples, replace=False)\n",
+ "\n",
+ "fig, axes = plt.subplots(n_examples, 2, figsize=(8, 3 * n_examples))\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " x_batch = X_tensor[indices]\n",
+ " preds = net(x_batch).detach().cpu().numpy()\n",
+ " truths = Y_tensor[indices].detach().cpu().numpy()\n",
+ "\n",
+ "for i in range(n_examples):\n",
+ " axes[i, 0].imshow(truths[i].reshape(p1, p2), cmap=\"RdBu_r\")\n",
+ " axes[i, 0].set_title(\"Ground truth\")\n",
+ " axes[i, 1].imshow(preds[i].reshape(p1, p2), cmap=\"RdBu_r\")\n",
+ " axes[i, 1].set_title(\"Prediction\")\n",
+ "\n",
+ "plt.suptitle(f\"Predictions (sequential $k={k}$ on $C_{{{p1}}} \\\\times C_{{{p2}}}$, epoch {final_epoch})\", fontsize=16)\n",
+ "plt.tight_layout()\n",
+ "plt.savefig(f\"{FIGURES_DIR}/sequential_cnxcn_predictions.pdf\", bbox_inches=\"tight\")\n",
+ "plt.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "group-agf",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.12.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/notebooks/template_bar.pdf b/notebooks/template_bar.pdf
deleted file mode 100644
index 63f21bf..0000000
Binary files a/notebooks/template_bar.pdf and /dev/null differ
diff --git a/notebooks/template_fft_bar.pdf b/notebooks/template_fft_bar.pdf
deleted file mode 100644
index 808eaf6..0000000
Binary files a/notebooks/template_fft_bar.pdf and /dev/null differ
diff --git a/notebooks/znz_znz.ipynb b/notebooks/znz_znz.ipynb
deleted file mode 100644
index 9e38ce6..0000000
--- a/notebooks/znz_znz.ipynb
+++ /dev/null
@@ -1,234 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "51d11caf-0971-4324-b63b-819b714a9c3c",
- "metadata": {},
- "source": [
- "# Learning Z/nZ x Z/nZ group actions\n",
- "This notebook is adapted from the `modular arithmetic` notebook, replacing `Z/nZ` group action with `Z/nZ x Z/nZ` group action "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "80f249f1-6985-4c73-86cd-04e1adac3e8d",
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import random\n",
- "import torch\n",
- "import os\n",
- "import torch.nn as nn\n",
- "import torch.optim as optim\n",
- "import shutil\n",
- "from torch.utils.data import DataLoader, TensorDataset\n",
- "import matplotlib.pyplot as plt\n",
- "import matplotlib.cm as cm\n",
- "from matplotlib.animation import FuncAnimation\n",
- "from matplotlib.ticker import FormatStrFormatter\n",
- "from matplotlib.ticker import FuncFormatter\n",
- "from matplotlib.ticker import MaxNLocator\n",
- "\n",
- "import importlib\n",
- "import pickle\n",
- "\n",
- "import group_agf.binary_action_learning.models as models\n",
- "import group_agf.binary_action_learning.datasets as datasets\n",
- "import group_agf.binary_action_learning.power as power\n",
- "import group_agf.binary_action_learning.train as train\n",
- "import group_agf.binary_action_learning.plot as plot\n"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "da0f43df",
- "metadata": {},
- "source": [
- "# Define Dataset and Visualize"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "f5dba48b",
- "metadata": {},
- "outputs": [],
- "source": [
- "from group_agf.binary_action_learning.default_config import verbose_interval\n",
- "import os\n",
- "\n",
- "# TEST_MODE: Set to reduce epochs for automated testing\n",
- "TEST_MODE = os.environ.get(\"NOTEBOOK_TEST_MODE\", \"0\") == \"1\"\n",
- "\n",
- "p = 3 if TEST_MODE else 5 # Reduced in test mode\n",
- "mnist_digit = 4\n",
- "dataset_fraction = 0.1 if TEST_MODE else 0.2 # Reduced in test mode\n",
- "template_type = 'mnist'\n",
- "seed = 47\n",
- "batch_size = 32 if TEST_MODE else 128 # Reduced in test mode\n",
- "hidden_size = 32 if TEST_MODE else 128 # Reduced in test mode\n",
- "lr = 0.001\n",
- "mom = 0.9\n",
- "init_scale = 1e-2\n",
- "epochs = 2 if TEST_MODE else 1000\n",
- "verbose_interval = max(1, epochs // 10)\n",
- "\n",
- "model_save_path = (\n",
- " f\"/tmp/adele/model_\"\n",
- " f\"p{p}_\"\n",
- " f\"digit{mnist_digit}_\"\n",
- " f\"frac{dataset_fraction}_\"\n",
- " f\"type{template_type}_\"\n",
- " f\"seed{seed}.pkl\"\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "615fe334",
- "metadata": {},
- "outputs": [],
- "source": [
- "template = datasets.choose_template(p, template_type, mnist_digit)\n",
- "group = 'cnxcn'\n",
- "\n",
- "top_frequency_plot = plot.plot_top_template_components(group, template, p)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a28838b7",
- "metadata": {},
- "outputs": [],
- "source": [
- "X, Y, translations = datasets.load_modular_addition_dataset_2d(p, template, fraction=dataset_fraction, random_state=seed, template_type=template_type)\n",
- "\n",
- "X, Y, device = datasets.move_dataset_to_device_and_flatten(X, Y, p, device=None)\n",
- "\n",
- "dataset = TensorDataset(X, Y)\n",
- "dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "111a2530",
- "metadata": {},
- "source": [
- "# Define Model and Train"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "00b56ec1",
- "metadata": {},
- "outputs": [],
- "source": [
- "np.random.seed(seed)\n",
- "torch.manual_seed(seed)\n",
- "torch.cuda.manual_seed_all(seed) # if using GPU\n",
- "\n",
- "model = models.TwoLayerNet(p=p, hidden_size=hidden_size, nonlinearity='square', init_scale=init_scale, output_scale=1e0)\n",
- "model = model.to(device)\n",
- "loss = nn.MSELoss()\n",
- "optimizer = optim.Adam(model.parameters(), lr=lr, betas=(mom, 0.999))\n",
- "\n",
- "loss_history, accuracy_history, param_history = train.train(\n",
- " model,\n",
- " dataloader,\n",
- " loss,\n",
- " optimizer,\n",
- " epochs=epochs,\n",
- " verbose_interval=verbose_interval,\n",
- " model_save_path=model_save_path\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "e3b19dcc",
- "metadata": {},
- "source": [
- "# Plot loss, power, and model output"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5989bcd9",
- "metadata": {},
- "outputs": [],
- "source": [
- "loss_plot = plot.plot_loss_curve(loss_history, template)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b1965705",
- "metadata": {},
- "outputs": [],
- "source": [
- "template_2d = template.reshape((p, p))\n",
- "power_over_training_plot = plot.plot_training_power_over_time(template_2d, model, device, param_history, X, p, save_path=None, show=False) "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d0319335",
- "metadata": {},
- "outputs": [],
- "source": [
- "neuron_indices = list(range(20))\n",
- "group= 'cnxcn'\n",
- "print(neuron_indices)\n",
- "neuron_weights_plot = plot.plot_neuron_weights(group, model, p, neuron_indices=neuron_indices, show=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "80fc35d9",
- "metadata": {},
- "outputs": [],
- "source": [
- "idx = 13\n",
- "plot.plot_model_outputs(p, model, X, Y, idx)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "296b3d56",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "group-agf",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.12"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/test/test_notebooks.py b/test/test_notebooks.py
index a87daf6..9352ebb 100644
--- a/test/test_notebooks.py
+++ b/test/test_notebooks.py
@@ -22,6 +22,7 @@
import os
import subprocess
import sys
+import tempfile
from pathlib import Path
import pytest
@@ -39,19 +40,8 @@ def get_notebooks_dir():
# Notebooks to skip (with reasons)
SKIP_NOTEBOOKS = {
- # These notebooks have hardcoded paths to /home/facosta/ which don't exist
- "seq_mlp_group_size": "Has hardcoded paths to /home/facosta/ filesystem",
- "rnn_gagf": "Has hardcoded paths to /home/facosta/ filesystem",
- # These notebooks require pre-trained model files or external data
- "paper_figures": "Requires pre-trained model .pkl files not included in repo",
- # These notebooks have import/code issues that need separate debugging
- "2D": "Missing function: cannot import 'get_power_2d' from src.power",
- "znz_znz": "Missing function: datasets.choose_template() does not exist",
- "seq_mlp": "Plotting error: Invalid vmin/vmax values during visualization",
- # These notebooks have visualization code with hardcoded indices that fail with reduced p
- "C_n": "IndexError in visualization code when running with reduced parameters",
- "dihedral": "IndexError in visualization code when running with reduced parameters",
- "modular_arithmetic": "IndexError in visualization code when running with reduced parameters",
+ # Add notebooks here if they need to be skipped, e.g.:
+ # "notebook_name": "Reason for skipping",
}
@@ -89,32 +79,34 @@ def execute_notebook(notebook_path, env):
tuple: (success: bool, error_message: str or None)
"""
try:
- result = subprocess.run(
- [
- sys.executable,
- "-m",
- "jupyter",
- "nbconvert",
- "--to",
- "notebook",
- "--execute",
- "--ExecutePreprocessor.timeout=300", # 5 minute timeout per notebook
- "--ExecutePreprocessor.kernel_name=python3",
- "--output",
- "/dev/null",
- str(notebook_path),
- ],
- capture_output=True,
- text=True,
- env=env,
- cwd=str(get_repo_root()),
- timeout=360, # 6 minute overall timeout
- )
-
- if result.returncode != 0:
- error_msg = f"STDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}"
- return False, error_msg
- return True, None
+ with tempfile.TemporaryDirectory() as tmpdir:
+ output_path = os.path.join(tmpdir, "output.ipynb")
+ result = subprocess.run(
+ [
+ sys.executable,
+ "-m",
+ "jupyter",
+ "nbconvert",
+ "--to",
+ "notebook",
+ "--execute",
+ "--ExecutePreprocessor.timeout=300", # 5 minute timeout per notebook
+ "--ExecutePreprocessor.kernel_name=python3",
+ "--output",
+ output_path,
+ str(notebook_path),
+ ],
+ capture_output=True,
+ text=True,
+ env=env,
+ cwd=str(get_repo_root()),
+ timeout=360, # 6 minute overall timeout
+ )
+
+ if result.returncode != 0:
+ error_msg = f"STDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}"
+ return False, error_msg
+ return True, None
except subprocess.TimeoutExpired:
return False, "Notebook execution timed out (>6 minutes)"