diff --git a/gagf/rnns/__init__.py b/gagf/rnns/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/gagf/rnns/create_combined_power_plot.py b/gagf/rnns/create_combined_power_plot.py deleted file mode 100644 index ad75d80..0000000 --- a/gagf/rnns/create_combined_power_plot.py +++ /dev/null @@ -1,335 +0,0 @@ -#!/usr/bin/env python3 -""" -Create a combined 4x3 plot showing power spectrum evolution for k=2,3,4,5. -Each row corresponds to a k value, each column to a scale type (linear, log-x, log-log). -""" - -import os -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import torch -import yaml -from escnn.group import DihedralGroup - -from gagf.rnns.datamodule import build_modular_addition_sequence_dataset_D3 -from gagf.rnns.model import SequentialMLP -from group_agf.binary_action_learning.group_fourier_transform import compute_group_fourier_coef - - -def load_run_data(run_dir): - """Load all necessary data from a run directory.""" - run_dir = Path(run_dir) - - # Load config - with open(run_dir / "config.yaml") as f: - config = yaml.safe_load(f) - - # Load template - template = np.load(run_dir / "template.npy") - - # Load parameter history - param_hist = torch.load(run_dir / "param_history.pt", map_location="cpu") - - # Get k value - k = config["data"]["k"] - - # Create model - D3 = DihedralGroup(N=3) - group_order = D3.order() - template_t = torch.tensor(template, dtype=torch.float32) - model = SequentialMLP( - p=group_order, - d=config["model"]["hidden_dim"], - template=template_t, - k=k, - init_scale=config["model"]["init_scale"], - ) - - # Generate evaluation data - X_eval, Y_eval, _ = build_modular_addition_sequence_dataset_D3( - template, k, mode="sampled", num_samples=100, return_all_outputs=False - ) - X_eval_t = torch.tensor(X_eval, dtype=torch.float32) - - # Compute template power - irreps = D3.irreps() - n_irreps = len(irreps) - template_power = np.zeros(n_irreps) - for i, irrep in enumerate(irreps): - fourier_coef = compute_group_fourier_coef(D3, template, irrep) - template_power[i] = irrep.size * np.trace(fourier_coef.conj().T @ fourier_coef) - template_power = template_power / group_order - - # Compute param_save_indices - save_interval = config["training"].get("save_param_interval", 10) or 10 - param_save_indices = [i * save_interval for i in range(len(param_hist))] - - return { - "config": config, - "template": template, - "param_hist": param_hist, - "param_save_indices": param_save_indices, - "model": model, - "X_eval_t": X_eval_t, - "D3": D3, - "template_power": template_power, - "k": k, - } - - -def compute_power_evolution(run_data, num_checkpoints_to_sample=50, num_samples_for_power=100): - """Compute power evolution over training.""" - model = run_data["model"] - param_hist = run_data["param_hist"] - param_save_indices = run_data["param_save_indices"] - X_eval_t = run_data["X_eval_t"] - D3 = run_data["D3"] - - irreps = D3.irreps() - n_irreps = len(irreps) - - # Sample checkpoints - total_checkpoints = len(param_hist) - if total_checkpoints <= num_checkpoints_to_sample: - sampled_ckpt_indices = list(range(total_checkpoints)) - else: - sampled_ckpt_indices = np.linspace( - 0, total_checkpoints - 1, num_checkpoints_to_sample, dtype=int - ).tolist() - - epoch_numbers = [param_save_indices[i] for i in sampled_ckpt_indices] - - # Compute model output power at each checkpoint - model_powers = np.zeros((len(sampled_ckpt_indices), n_irreps)) - X_subset = X_eval_t[:num_samples_for_power] - - for i, ckpt_idx in 
enumerate(sampled_ckpt_indices): - model.load_state_dict(param_hist[ckpt_idx]) - model.eval() - - with torch.no_grad(): - outputs = model(X_subset) - outputs_np = outputs.cpu().numpy() - - powers = np.zeros((len(outputs_np), n_irreps)) - for sample_i, output in enumerate(outputs_np): - for irrep_i, irrep in enumerate(irreps): - fourier_coef = compute_group_fourier_coef(D3, output, irrep) - powers[sample_i, irrep_i] = irrep.size * np.trace( - fourier_coef.conj().T @ fourier_coef - ) - powers = powers / D3.order() - model_powers[i] = np.mean(powers, axis=0) - - return epoch_numbers, model_powers - - -def create_combined_plot(run_dirs_dict, save_path): - """Create 4x3 combined plot.""" - # run_dirs_dict: {k: run_dir} - - fig = plt.figure(figsize=(18, 20)) - gs = fig.add_gridspec(5, 3, height_ratios=[0.15, 1, 1, 1, 1], hspace=0.3, wspace=0.3) - - # Top row for common parameters - ax_title = fig.add_subplot(gs[0, :]) - ax_title.axis("off") - - # Load first run to get common parameters - first_run_dir = list(run_dirs_dict.values())[0] - first_run_data = load_run_data(first_run_dir) - common_config = first_run_data["config"] - - # Extract common parameters - hidden_dim = common_config["model"]["hidden_dim"] - mode = common_config["data"]["mode"] - optimizer = common_config["training"]["optimizer"] - - # Create title with common parameters - title_text = "D3 Power Spectrum Evolution Over Training\n" - title_text += f"Common Parameters: hidden_dim={hidden_dim}, mode={mode}, optimizer={optimizer}" - ax_title.text( - 0.5, - 0.5, - title_text, - ha="center", - va="center", - fontsize=14, - fontweight="bold", - transform=ax_title.transAxes, - ) - - # Create axes for plots with shared x-axes for each column - axes = [] - # First create all axes in the first row (no sharing yet) - for col in range(3): - axes.append([fig.add_subplot(gs[1, col])]) - - # Then create remaining rows sharing x-axis with first row in each column - for row in range(1, 4): - for col in range(3): - axes[col].append(fig.add_subplot(gs[row + 1, col], sharex=axes[col][0])) - - # Convert to row-major format for easier indexing - axes = np.array([[axes[col][row] for col in range(3)] for row in range(4)]) - - k_values = sorted(run_dirs_dict.keys()) - - for row_idx, k in enumerate(k_values): - run_dir = run_dirs_dict[k] - print(f"Loading k={k} from {run_dir}...") - run_data = load_run_data(run_dir) - - epoch_numbers, model_powers = compute_power_evolution(run_data) - template_power = run_data["template_power"] - D3 = run_data["D3"] - irreps = D3.irreps() - config = run_data["config"] - - # Get row-specific parameters - learning_rate = config["training"]["learning_rate"] - init_scale = config["model"]["init_scale"] - - # Format init_scale nicely - if init_scale >= 1e-3: - init_scale_str = f"{init_scale:.0e}" - elif init_scale >= 1e-6: - init_scale_str = f"{init_scale:.1e}" - else: - init_scale_str = f"{init_scale:.2e}" - - # Format learning_rate nicely - if learning_rate >= 1e-3: - lr_str = f"{learning_rate:.0e}" - elif learning_rate >= 1e-6: - lr_str = f"{learning_rate:.1e}" - else: - lr_str = f"{learning_rate:.2e}" - - # Row label - row_label = f"k={k}, lr={lr_str}, init_scale={init_scale_str}" - - # Get top irreps - top_k_irreps = min(5, len(irreps)) - top_irrep_indices = np.argsort(template_power)[::-1][:top_k_irreps] - colors_line = plt.cm.tab10(np.linspace(0, 1, top_k_irreps)) - - # Filter for log scales - valid_mask = np.array(epoch_numbers) > 0 - valid_epochs = np.array(epoch_numbers)[valid_mask] - valid_model_powers = 
model_powers[valid_mask, :] - - # Column 1: Linear scales - ax = axes[row_idx, 0] - for i, irrep_idx in enumerate(top_irrep_indices): - power_values = model_powers[:, irrep_idx] - ax.plot( - epoch_numbers, - power_values, - "-", - lw=2, - color=colors_line[i], - label=f"Irrep {irrep_idx} (dim={irreps[irrep_idx].size})", - ) - ax.axhline(template_power[irrep_idx], linestyle="--", alpha=0.5, color=colors_line[i]) - if row_idx == 3: # Only bottom row shows xlabel - ax.set_xlabel("Epoch") - ax.set_ylabel("Power") - if row_idx == 0: - col_title = "Linear Scales" - else: - col_title = "" - ax.set_title(f"{col_title}\n{row_label}", fontsize=12 if row_idx == 0 else 10) - ax.legend(loc="upper left", fontsize=7) - ax.grid(True, alpha=0.3) - # Hide x-axis labels for non-bottom rows (they're shared) - if row_idx < 3: - ax.tick_params(labelbottom=False) - - # Column 2: Log x-axis - ax = axes[row_idx, 1] - for i, irrep_idx in enumerate(top_irrep_indices): - power_values = valid_model_powers[:, irrep_idx] - ax.plot( - valid_epochs, - power_values, - "-", - lw=2, - color=colors_line[i], - label=f"Irrep {irrep_idx} (dim={irreps[irrep_idx].size})", - ) - ax.axhline(template_power[irrep_idx], linestyle="--", alpha=0.5, color=colors_line[i]) - ax.set_xscale("log") - if row_idx == 3: # Only bottom row shows xlabel - ax.set_xlabel("Epoch (log scale)") - ax.set_ylabel("Power") - if row_idx == 0: - col_title = "Log X-axis" - else: - col_title = "" - ax.set_title(f"{col_title}\n{row_label}", fontsize=12 if row_idx == 0 else 10) - ax.legend(loc="upper left", fontsize=7) - ax.grid(True, alpha=0.3) - # Hide x-axis labels for non-bottom rows (they're shared) - if row_idx < 3: - ax.tick_params(labelbottom=False) - - # Column 3: Log-log scales - ax = axes[row_idx, 2] - for i, irrep_idx in enumerate(top_irrep_indices): - power_values = valid_model_powers[:, irrep_idx] - power_mask = power_values > 0 - if np.any(power_mask): - ax.plot( - valid_epochs[power_mask], - power_values[power_mask], - "-", - lw=2, - color=colors_line[i], - label=f"Irrep {irrep_idx} (dim={irreps[irrep_idx].size})", - ) - if template_power[irrep_idx] > 0: - ax.axhline( - template_power[irrep_idx], linestyle="--", alpha=0.5, color=colors_line[i] - ) - ax.set_xscale("log") - ax.set_yscale("log") - if row_idx == 3: # Only bottom row shows xlabel - ax.set_xlabel("Epoch (log scale)") - ax.set_ylabel("Power (log scale)") - if row_idx == 0: - col_title = "Log-Log Scales" - else: - col_title = "" - ax.set_title(f"{col_title}\n{row_label}", fontsize=12 if row_idx == 0 else 10) - ax.legend(loc="upper left", fontsize=7) - ax.grid(True, alpha=0.3) - # Hide x-axis labels for non-bottom rows (they're shared) - if row_idx < 3: - ax.tick_params(labelbottom=False) - - plt.savefig(save_path, bbox_inches="tight", dpi=150) - print(f"\n✓ Saved combined plot to {save_path}") - plt.close() - - -if __name__ == "__main__": - # Map k values to run directories - # k=2, k=3: 10000 epochs - # k=4, k=5: 20000 epochs - run_dirs_dict = { - 2: "runs/20260114_134641", - 3: "runs/20260114_134800", - 4: "runs/20260114_141256", - 5: "runs/20260114_141951", - } - - # Verify all directories exist - for k, run_dir in run_dirs_dict.items(): - if not os.path.exists(run_dir): - print(f"Warning: {run_dir} does not exist for k={k}") - - save_path = "runs/combined_power_spectrum_4x3.pdf" - create_combined_plot(run_dirs_dict, save_path) diff --git a/gagf/rnns/create_combined_power_plot_k4_k5.py b/gagf/rnns/create_combined_power_plot_k4_k5.py deleted file mode 100644 index ec79a6e..0000000 --- 
a/gagf/rnns/create_combined_power_plot_k4_k5.py +++ /dev/null @@ -1,331 +0,0 @@ -#!/usr/bin/env python3 -""" -Create a combined 2x3 plot showing power spectrum evolution for k=4 and k=5. -Each row corresponds to a k value, each column to a scale type (linear, log-x, log-log). -""" - -import os -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import torch -import yaml -from escnn.group import DihedralGroup - -from gagf.rnns.datamodule import build_modular_addition_sequence_dataset_D3 -from gagf.rnns.model import SequentialMLP -from group_agf.binary_action_learning.group_fourier_transform import compute_group_fourier_coef - - -def load_run_data(run_dir): - """Load all necessary data from a run directory.""" - run_dir = Path(run_dir) - - # Load config - with open(run_dir / "config.yaml") as f: - config = yaml.safe_load(f) - - # Load template - template = np.load(run_dir / "template.npy") - - # Load parameter history - param_hist = torch.load(run_dir / "param_history.pt", map_location="cpu") - - # Get k value - k = config["data"]["k"] - - # Create model - D3 = DihedralGroup(N=3) - group_order = D3.order() - template_t = torch.tensor(template, dtype=torch.float32) - model = SequentialMLP( - p=group_order, - d=config["model"]["hidden_dim"], - template=template_t, - k=k, - init_scale=config["model"]["init_scale"], - ) - - # Generate evaluation data - X_eval, Y_eval, _ = build_modular_addition_sequence_dataset_D3( - template, k, mode="sampled", num_samples=100, return_all_outputs=False - ) - X_eval_t = torch.tensor(X_eval, dtype=torch.float32) - - # Compute template power - irreps = D3.irreps() - n_irreps = len(irreps) - template_power = np.zeros(n_irreps) - for i, irrep in enumerate(irreps): - fourier_coef = compute_group_fourier_coef(D3, template, irrep) - template_power[i] = irrep.size * np.trace(fourier_coef.conj().T @ fourier_coef) - template_power = template_power / group_order - - # Compute param_save_indices - save_interval = config["training"].get("save_param_interval", 10) or 10 - param_save_indices = [i * save_interval for i in range(len(param_hist))] - - return { - "config": config, - "template": template, - "param_hist": param_hist, - "param_save_indices": param_save_indices, - "model": model, - "X_eval_t": X_eval_t, - "D3": D3, - "template_power": template_power, - "k": k, - } - - -def compute_power_evolution(run_data, num_checkpoints_to_sample=50, num_samples_for_power=100): - """Compute power evolution over training.""" - model = run_data["model"] - param_hist = run_data["param_hist"] - param_save_indices = run_data["param_save_indices"] - X_eval_t = run_data["X_eval_t"] - D3 = run_data["D3"] - - irreps = D3.irreps() - n_irreps = len(irreps) - - # Sample checkpoints - total_checkpoints = len(param_hist) - if total_checkpoints <= num_checkpoints_to_sample: - sampled_ckpt_indices = list(range(total_checkpoints)) - else: - sampled_ckpt_indices = np.linspace( - 0, total_checkpoints - 1, num_checkpoints_to_sample, dtype=int - ).tolist() - - epoch_numbers = [param_save_indices[i] for i in sampled_ckpt_indices] - - # Compute model output power at each checkpoint - model_powers = np.zeros((len(sampled_ckpt_indices), n_irreps)) - X_subset = X_eval_t[:num_samples_for_power] - - for i, ckpt_idx in enumerate(sampled_ckpt_indices): - model.load_state_dict(param_hist[ckpt_idx]) - model.eval() - - with torch.no_grad(): - outputs = model(X_subset) - outputs_np = outputs.cpu().numpy() - - powers = np.zeros((len(outputs_np), n_irreps)) - for sample_i, output in 
enumerate(outputs_np): - for irrep_i, irrep in enumerate(irreps): - fourier_coef = compute_group_fourier_coef(D3, output, irrep) - powers[sample_i, irrep_i] = irrep.size * np.trace( - fourier_coef.conj().T @ fourier_coef - ) - powers = powers / D3.order() - model_powers[i] = np.mean(powers, axis=0) - - return epoch_numbers, model_powers - - -def create_combined_plot(run_dirs_dict, save_path): - """Create 2x3 combined plot.""" - # run_dirs_dict: {k: run_dir} - - fig = plt.figure(figsize=(18, 10)) - gs = fig.add_gridspec(3, 3, height_ratios=[0.15, 1, 1], hspace=0.3, wspace=0.3) - - # Top row for common parameters - ax_title = fig.add_subplot(gs[0, :]) - ax_title.axis("off") - - # Load first run to get common parameters - first_run_dir = list(run_dirs_dict.values())[0] - first_run_data = load_run_data(first_run_dir) - common_config = first_run_data["config"] - - # Extract common parameters - hidden_dim = common_config["model"]["hidden_dim"] - mode = common_config["data"]["mode"] - optimizer = common_config["training"]["optimizer"] - - # Create title with common parameters - title_text = "D3 Power Spectrum Evolution Over Training\n" - title_text += f"Common Parameters: hidden_dim={hidden_dim}, mode={mode}, optimizer={optimizer}" - ax_title.text( - 0.5, - 0.5, - title_text, - ha="center", - va="center", - fontsize=14, - fontweight="bold", - transform=ax_title.transAxes, - ) - - # Create axes for plots with shared x-axes for each column - axes = [] - # First create all axes in the first row (no sharing yet) - for col in range(3): - axes.append([fig.add_subplot(gs[1, col])]) - - # Then create remaining rows sharing x-axis with first row in each column - for row in range(1, 2): - for col in range(3): - axes[col].append(fig.add_subplot(gs[row + 1, col], sharex=axes[col][0])) - - # Convert to row-major format for easier indexing - axes = np.array([[axes[col][row] for col in range(3)] for row in range(2)]) - - k_values = sorted(run_dirs_dict.keys()) - - for row_idx, k in enumerate(k_values): - run_dir = run_dirs_dict[k] - print(f"Loading k={k} from {run_dir}...") - run_data = load_run_data(run_dir) - - epoch_numbers, model_powers = compute_power_evolution(run_data) - template_power = run_data["template_power"] - D3 = run_data["D3"] - irreps = D3.irreps() - config = run_data["config"] - - # Get row-specific parameters - learning_rate = config["training"]["learning_rate"] - init_scale = config["model"]["init_scale"] - - # Format init_scale nicely - if init_scale >= 1e-3: - init_scale_str = f"{init_scale:.0e}" - elif init_scale >= 1e-6: - init_scale_str = f"{init_scale:.1e}" - else: - init_scale_str = f"{init_scale:.2e}" - - # Format learning_rate nicely - if learning_rate >= 1e-3: - lr_str = f"{learning_rate:.0e}" - elif learning_rate >= 1e-6: - lr_str = f"{learning_rate:.1e}" - else: - lr_str = f"{learning_rate:.2e}" - - # Row label - row_label = f"k={k}, lr={lr_str}, init_scale={init_scale_str}" - - # Get top irreps - top_k_irreps = min(5, len(irreps)) - top_irrep_indices = np.argsort(template_power)[::-1][:top_k_irreps] - colors_line = plt.cm.tab10(np.linspace(0, 1, top_k_irreps)) - - # Filter for log scales - valid_mask = np.array(epoch_numbers) > 0 - valid_epochs = np.array(epoch_numbers)[valid_mask] - valid_model_powers = model_powers[valid_mask, :] - - # Column 1: Linear scales - ax = axes[row_idx, 0] - for i, irrep_idx in enumerate(top_irrep_indices): - power_values = model_powers[:, irrep_idx] - ax.plot( - epoch_numbers, - power_values, - "-", - lw=2, - color=colors_line[i], - label=f"Irrep 
{irrep_idx} (dim={irreps[irrep_idx].size})", - ) - ax.axhline(template_power[irrep_idx], linestyle="--", alpha=0.5, color=colors_line[i]) - if row_idx == 1: # Only bottom row shows xlabel - ax.set_xlabel("Epoch") - ax.set_ylabel("Power") - if row_idx == 0: - col_title = "Linear Scales" - else: - col_title = "" - ax.set_title(f"{col_title}\n{row_label}", fontsize=12 if row_idx == 0 else 10) - ax.legend(loc="upper left", fontsize=7) - ax.grid(True, alpha=0.3) - # Hide x-axis labels for non-bottom rows (they're shared) - if row_idx < 1: - ax.tick_params(labelbottom=False) - - # Column 2: Log x-axis - ax = axes[row_idx, 1] - for i, irrep_idx in enumerate(top_irrep_indices): - power_values = valid_model_powers[:, irrep_idx] - ax.plot( - valid_epochs, - power_values, - "-", - lw=2, - color=colors_line[i], - label=f"Irrep {irrep_idx} (dim={irreps[irrep_idx].size})", - ) - ax.axhline(template_power[irrep_idx], linestyle="--", alpha=0.5, color=colors_line[i]) - ax.set_xscale("log") - if row_idx == 1: # Only bottom row shows xlabel - ax.set_xlabel("Epoch (log scale)") - ax.set_ylabel("Power") - if row_idx == 0: - col_title = "Log X-axis" - else: - col_title = "" - ax.set_title(f"{col_title}\n{row_label}", fontsize=12 if row_idx == 0 else 10) - ax.legend(loc="upper left", fontsize=7) - ax.grid(True, alpha=0.3) - # Hide x-axis labels for non-bottom rows (they're shared) - if row_idx < 1: - ax.tick_params(labelbottom=False) - - # Column 3: Log-log scales - ax = axes[row_idx, 2] - for i, irrep_idx in enumerate(top_irrep_indices): - power_values = valid_model_powers[:, irrep_idx] - power_mask = power_values > 0 - if np.any(power_mask): - ax.plot( - valid_epochs[power_mask], - power_values[power_mask], - "-", - lw=2, - color=colors_line[i], - label=f"Irrep {irrep_idx} (dim={irreps[irrep_idx].size})", - ) - if template_power[irrep_idx] > 0: - ax.axhline( - template_power[irrep_idx], linestyle="--", alpha=0.5, color=colors_line[i] - ) - ax.set_xscale("log") - ax.set_yscale("log") - if row_idx == 1: # Only bottom row shows xlabel - ax.set_xlabel("Epoch (log scale)") - ax.set_ylabel("Power (log scale)") - if row_idx == 0: - col_title = "Log-Log Scales" - else: - col_title = "" - ax.set_title(f"{col_title}\n{row_label}", fontsize=12 if row_idx == 0 else 10) - ax.legend(loc="upper left", fontsize=7) - ax.grid(True, alpha=0.3) - # Hide x-axis labels for non-bottom rows (they're shared) - if row_idx < 1: - ax.tick_params(labelbottom=False) - - plt.savefig(save_path, bbox_inches="tight", dpi=150) - print(f"\n✓ Saved combined plot to {save_path}") - plt.close() - - -if __name__ == "__main__": - # Map k values to run directories - run_dirs_dict = { - 4: "runs/20260114_170639", - 5: "runs/20260114_170913", - } - - # Verify all directories exist - for k, run_dir in run_dirs_dict.items(): - if not os.path.exists(run_dir): - print(f"Warning: {run_dir} does not exist for k={k}") - - save_path = "runs/combined_power_spectrum_k4_k5_2x3.pdf" - create_combined_plot(run_dirs_dict, save_path) diff --git a/group_agf/binary_action_learning/main.py b/group_agf/binary_action_learning/main.py index 9e91236..f605f83 100644 --- a/group_agf/binary_action_learning/main.py +++ b/group_agf/binary_action_learning/main.py @@ -12,12 +12,12 @@ from torch.utils.data import DataLoader, TensorDataset import default_config -import group_agf.binary_action_learning.datasets as datasets -import group_agf.binary_action_learning.models as models -import group_agf.binary_action_learning.plot as plot -import group_agf.binary_action_learning.power as 
power import group_agf.binary_action_learning.train as train -from group_agf.binary_action_learning.optimizer import PerNeuronScaledSGD +import src.datasets as datasets +import src.model as models +import src.plot as plot +import src.power as power +from src.optimizers import PerNeuronScaledSGD today = datetime.date.today() diff --git a/group_agf/binary_action_learning/models.py b/group_agf/binary_action_learning/models.py deleted file mode 100644 index 9750cb3..0000000 --- a/group_agf/binary_action_learning/models.py +++ /dev/null @@ -1,67 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn - - -class TwoLayerNet(nn.Module): - def __init__( - self, - group_size, - hidden_size=None, - nonlinearity="square", - init_scale=1.0, - output_scale=1.0, - ): - super().__init__() - self.group_size = group_size - if hidden_size is None: - # hidden_size = 6 * group_size - hidden_size = 50 * group_size - self.hidden_size = hidden_size - self.nonlinearity = nonlinearity - self.init_scale = init_scale - self.output_scale = output_scale - - # Initialize parameters - self.U = nn.Parameter( - self.init_scale - * torch.randn(hidden_size, self.group_size) - / np.sqrt(2 * self.group_size) - ) - self.V = nn.Parameter( - self.init_scale - * torch.randn(hidden_size, self.group_size) - / np.sqrt(2 * self.group_size) - ) - self.W = nn.Parameter( - self.init_scale * torch.randn(hidden_size, self.group_size) / np.sqrt(self.group_size) - ) # Second layer weights - - def forward(self, x): - # First layer (linear and combined) - x1 = x[:, : self.group_size] @ self.U.T - x2 = x[:, self.group_size :] @ self.V.T - x_combined = x1 + x2 - - # Apply nonlinearity activation - if self.nonlinearity == "relu": - x_combined = torch.relu(x_combined) - elif self.nonlinearity == "square": - x_combined = x_combined**2 - elif self.nonlinearity == "linear": - x_combined = x_combined - elif self.nonlinearity == "tanh": - x_combined = torch.tanh(x_combined) - elif self.nonlinearity == "gelu": - gelu = torch.nn.GELU() - x_combined = gelu(x_combined) - else: - raise ValueError(f"Invalid nonlinearity '{self.nonlinearity}' provided.") - - # Second layer (linear) - x_out = x_combined @ self.W - - # Feature learning scaling - x_out *= self.output_scale - - return x_out diff --git a/group_agf/binary_action_learning/optimizer.py b/group_agf/binary_action_learning/optimizer.py deleted file mode 100644 index fbb9ec7..0000000 --- a/group_agf/binary_action_learning/optimizer.py +++ /dev/null @@ -1,55 +0,0 @@ -import torch - - -class PerNeuronScaledSGD(torch.optim.Optimizer): - """SGD with per-neuron learning rate scaling: - eta_i = ||theta_i||^(1 - k) - where theta_i = (W_in[i,:], W_drive[i,:], W_out[:,i]). - - Args: - model: the model to optimize - lr: the learning rate - k: the degree of the nonlinearity, for binary composition: square k=2. 
- See: Appendix B.5 A neuron-specific adaptive learning rate - yields instantaneous alignment of AGF paper - """ - - def __init__(self, model, lr=1e-2, k=2): - params = [model.U, model.V, model.W] - # Print shape of parameters with their names - print(f"model.U shape: {model.U.shape}") - print(f"model.V shape: {model.V.shape}") - print(f"model.W shape: {model.W.shape}") - super().__init__([{"params": params, "model": model}], dict(lr=lr, k=k)) - - @torch.no_grad() - def step(self, closure=None): - group = self.param_groups[0] - model = group["model"] - lr = group["lr"] - k = group["k"] - U, V, W = model.U, model.V, model.W # each of shape (hidden_size, group_size) - g_U, g_V, g_W = ( - U.grad, - V.grad, - W.grad, - ) # each of shape (hidden_size, group_size) - if g_U is None or g_V is None or g_W is None: - return - # per-neuron norms - u2 = (U**2).sum(dim=1) # shape: (hidden_size,): nb of hidden neurons. - v2 = (V**2).sum(dim=1) # shape: (hidden_size,): nb of hidden neurons. - w2 = (W**2).sum(dim=1) # shape: (hidden_size,): nb of hidden neurons. - theta_norm = torch.sqrt( - u2 + v2 + w2 + 1e-12 - ) # shape: (hidden_size,): nb of hidden neurons. - # scale = ||theta_i||^(1 - k) - scale = theta_norm.pow(1 - k) # shape: (hidden_size,): nb of hidden neurons. - # scale each neuron's grads - g_U.mul_(scale.view(-1, 1)) - g_V.mul_(scale.view(-1, 1)) - g_W.mul_(scale.view(-1, 1)) - # SGD update - U.add_(g_U, alpha=-lr) - V.add_(g_V, alpha=-lr) - W.add_(g_W, alpha=-lr) diff --git a/pyproject.toml b/pyproject.toml index eec87ae..137106c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,7 @@ ignore = [ ] [tool.ruff.lint.isort] -known-first-party = ["gagf", "group_agf", "default_config"] +known-first-party = ["src", "group_agf", "default_config"] known-third-party = ["wandb"] [tool.pytest.ini_options] diff --git a/gagf/rnns/README.md b/src/README.md similarity index 100% rename from gagf/rnns/README.md rename to src/README.md diff --git a/gagf/__init__.py b/src/__init__.py similarity index 100% rename from gagf/__init__.py rename to src/__init__.py diff --git a/gagf/rnns/config.yaml b/src/config.yaml similarity index 77% rename from gagf/rnns/config.yaml rename to src/config.yaml index f13f51b..9a309c1 100644 --- a/gagf/rnns/config.yaml +++ b/src/config.yaml @@ -1,13 +1,16 @@ # ============================================================================ # Base Configuration File # ============================================================================ -# This config supports both 1D and 2D modular addition tasks with either -# QuadraticRNN or SequentialMLP models. +# This config supports various group tasks with either QuadraticRNN or +# SequentialMLP models. 
# # Quick Setup Guide: # ------------------ -# 1D Task (C_p): Set dimension=1, specify p -# 2D Task (C_p1 x C_p2): Set dimension=2, specify p1 and p2 +# Cyclic group C_p: Set group_name='cn', specify p +# Product group C_p1xC_p2: Set group_name='cnxcn', specify p1 and p2 +# Dihedral group D_n: Set group_name='dihedral', specify group_n +# Octahedral group: Set group_name='octahedral' +# Icosahedral (A5): Set group_name='A5' # QuadraticRNN: Set model_type='QuadraticRNN' # SequentialMLP: Set model_type='SequentialMLP' # ============================================================================ @@ -15,14 +18,18 @@ # Data Configuration # ------------------ data: - # Dimension: 1 for C_p (cyclic group), 2 for C_p1 x C_p2 (product group), 'D3' for Dihedral D3 - dimension: D3 # 1 | 2 | 'D3' + # Group name: 'cn' | 'cnxcn' | 'dihedral' | 'octahedral' | 'A5' + group_name: dihedral + + # Group order parameter (for parameterized groups) + # For dihedral: n in D_n (e.g., 3 for D3, 4 for D4) + group_n: 3 # Group Parameters - # For dimension=1: only 'p' is used - # For dimension=2: 'p1' and 'p2' are used - # For dimension='D3': none of p, p1, p2 are used - p: 10 # Cyclic group dimension (1D only) + # For group_name='cn': only 'p' is used + # For group_name='cnxcn': 'p1' and 'p2' are used + # For group_name='dihedral'/'octahedral'/'A5': p, p1, p2 are not used + p: 10 # Cyclic group dimension (cn only) p1: 4 #10 # Height/rows dimension (2D only) p2: 4 # Width/cols dimension (2D only) @@ -32,14 +39,14 @@ data: seed: 5 # Template Generation - # For dimension=1,2: 'mnist' | 'fourier' | 'gaussian' | 'onehot' - # For dimension='D3': 'onehot' | 'custom_fourier' + # For group_name='cn','cnxcn': 'mnist' | 'fourier' | 'gaussian' | 'onehot' + # For group_name='dihedral','octahedral','A5': 'onehot' | 'custom_fourier' template_type: onehot mnist_label: 4 # MNIST digit (0-9), only if template_type='mnist' n_freqs: 1 # Number of Fourier modes, only if template_type='fourier' - # D3 custom_fourier template: powers for each irrep's Fourier coefficient - # D3 has 3 irreps with dimensions [1, 1, 2], so powers should have 3 values + # custom_fourier template: powers for each irrep's Fourier coefficient + # Example for D3: 3 irreps with dimensions [1, 1, 2], so powers should have 3 values # Large ratio between powers = clearer staircase steps powers: - 0.0 @@ -86,7 +93,7 @@ training: learning_rate: 0.00008 # Base learning rate # Recommended settings: # - adam: 1e-3 to 1e-4 - # - per_neuron (SequentialMLP): 1.0 (or 0.01 for D3) + # - per_neuron (SequentialMLP): 1.0 (or 0.01 for dihedral) # - hybrid: see scaling_factor betas: diff --git a/src/config_a5.yaml b/src/config_a5.yaml new file mode 100644 index 0000000..306181e --- /dev/null +++ b/src/config_a5.yaml @@ -0,0 +1,46 @@ +# ============================================================================ +# Configuration: Icosahedral Group (A5) +# ============================================================================ +# A5 (Icosahedral) group has order 60 and 5 irreps with dimensions [1, 3, 5, 3, 4] + +data: + group_name: A5 + k: 2 + batch_size: 128 + seed: 10 + template_type: custom_fourier + mode: sampled + num_samples: 1000 + + # custom_fourier powers (one per irrep) + # A5 irreps: [1, 3, 5, 3, 4] + powers: [0.0, 1300.0, 0.0, 2000.0, 0.0] + +model: + model_type: TwoLayerNet + hidden_dim: 1200 # hidden_factor=30 * group_size=60 + init_scale: 0.001 # 1e-3 from default_config + nonlinearity: square + output_scale: 1.0 + return_all_outputs: false + transform_type: 
quadratic + +training: + mode: offline + epochs: 1000 + num_steps: 100 + optimizer: per_neuron + learning_rate: 0.001 + betas: [0.9, 0.999] + weight_decay: 0.0 + degree: null + scaling_factor: -3 + grad_clip: 0.1 + verbose_interval: 100 + save_param_interval: 10 + reduction_threshold: null + +device: cuda:0 + +analysis: + checkpoints: [0.0, 1.0] diff --git a/src/config_c10.yaml b/src/config_c10.yaml new file mode 100644 index 0000000..e6da970 --- /dev/null +++ b/src/config_c10.yaml @@ -0,0 +1,40 @@ +# ============================================================================ +# Configuration: Cyclic Group C_10 +# ============================================================================ + +data: + group_name: cn + p: 10 + k: 3 + batch_size: 128 + seed: 10 + template_type: onehot + mode: exhaustive + num_samples: 1000 + +model: + model_type: SequentialMLP + hidden_dim: 300 # hidden_factor=30 * group_size=10 + init_scale: 0.01 + return_all_outputs: false + transform_type: quadratic + +training: + mode: offline + epochs: 1000 + num_steps: 100 + optimizer: per_neuron + learning_rate: 0.01 + betas: [0.9, 0.999] + weight_decay: 0.0 + degree: null + scaling_factor: -3 + grad_clip: 0.1 + verbose_interval: 100 + save_param_interval: 10 + reduction_threshold: null + +device: cuda:0 + +analysis: + checkpoints: [0.0, 1.0] diff --git a/src/config_c4x4.yaml b/src/config_c4x4.yaml new file mode 100644 index 0000000..8c5d5e2 --- /dev/null +++ b/src/config_c4x4.yaml @@ -0,0 +1,42 @@ +# ============================================================================ +# Configuration: Product Group C_4 x C_4 +# ============================================================================ + +data: + group_name: cnxcn + p1: 4 + p2: 4 + k: 3 + batch_size: 128 + seed: 10 + template_type: fourier + n_freqs: 1 + mode: exhaustive + num_samples: 1000 + +model: + model_type: SequentialMLP + hidden_dim: 480 # hidden_factor=30 * group_size=16 + init_scale: 0.01 + return_all_outputs: false + transform_type: quadratic + +training: + mode: offline + epochs: 1000 + num_steps: 100 + optimizer: per_neuron + learning_rate: 0.01 + betas: [0.9, 0.999] + weight_decay: 0.0 + degree: null + scaling_factor: -3 + grad_clip: 0.1 + verbose_interval: 100 + save_param_interval: 10 + reduction_threshold: null + +device: cuda:0 + +analysis: + checkpoints: [0.0, 1.0] diff --git a/src/config_d3.yaml b/src/config_d3.yaml new file mode 100644 index 0000000..0122d35 --- /dev/null +++ b/src/config_d3.yaml @@ -0,0 +1,48 @@ +# ============================================================================ +# Configuration: Dihedral Group D3 +# ============================================================================ +# D3 has 3 irreps with dimensions [1, 1, 2] + +data: + group_name: dihedral + group_n: 3 + k: 2 + batch_size: 128 + seed: 10 + template_type: custom_fourier + mode: exhaustive + num_samples: 1000 + + # custom_fourier powers (one per irrep) + # D3 irreps: [1, 1, 2] + # Well-separated powers for clean staircase learning + powers: [0.0, 30.0, 3000.0] + +model: + model_type: TwoLayerNet + hidden_dim: 180 # hidden_factor=30 * group_size=6 + init_scale: 0.001 + nonlinearity: square + output_scale: 1.0 + return_all_outputs: false + transform_type: quadratic + +training: + mode: offline + epochs: 2000 + num_steps: 100 + optimizer: per_neuron + learning_rate: 0.01 + betas: [0.9, 0.999] + weight_decay: 0.0 + degree: null + scaling_factor: -3 + grad_clip: 0.1 + verbose_interval: 100 + save_param_interval: 1 + reduction_threshold: null + +device: 
cuda:0 + +analysis: + checkpoints: [0.0, 1.0] diff --git a/src/config_octahedral.yaml b/src/config_octahedral.yaml new file mode 100644 index 0000000..375dbec --- /dev/null +++ b/src/config_octahedral.yaml @@ -0,0 +1,46 @@ +# ============================================================================ +# Configuration: Octahedral Group +# ============================================================================ +# Octahedral group has order 24 and 5 irreps with dimensions [1, 3, 3, 2, 1] + +data: + group_name: octahedral + k: 2 + batch_size: 128 + seed: 10 + template_type: custom_fourier + mode: sampled + num_samples: 1000 + + # custom_fourier powers (one per irrep) + # Octahedral irreps: [1, 3, 3, 2, 1] + powers: [0.0, 0.0, 0.0, 300.0, 500.0] + +model: + model_type: TwoLayerNet + hidden_dim: 1200 # hidden_factor=50 * group_size=24 + init_scale: 0.001 + nonlinearity: square + output_scale: 1.0 + return_all_outputs: false + transform_type: quadratic + +training: + mode: offline + epochs: 1000 + num_steps: 100 + optimizer: per_neuron + learning_rate: 0.01 + betas: [0.9, 0.999] + weight_decay: 0.0 + degree: null + scaling_factor: -3 + grad_clip: 0.1 + verbose_interval: 200 + save_param_interval: 10 + reduction_threshold: null + +device: cuda:0 + +analysis: + checkpoints: [0.0, 1.0] diff --git a/gagf/rnns/datamodule.py b/src/datamodule.py similarity index 90% rename from gagf/rnns/datamodule.py rename to src/datamodule.py index b76b9ac..d15384f 100644 --- a/gagf/rnns/datamodule.py +++ b/src/datamodule.py @@ -341,21 +341,23 @@ def build_modular_addition_sequence_dataset_D3( mode: str = "sampled", num_samples: int = 65536, return_all_outputs: bool = False, + dihedral_n: int = 3, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ - Build D3 (dihedral group) composition dataset for sequence length k. + Build dihedral group composition dataset for sequence length k. - Uses the regular representation of D3 to transform the template. + Uses the regular representation of D_n to transform the template. For a sequence of k group elements (g1, g2, ..., gk), we compute: - X[i, t, :] = regular_rep(g_t) @ template (template transformed by g_t) - Y[i, :] = regular_rep(g1 * g2 * ... 
* gk) @ template (template transformed by composition) Args: - template: (group_order,) template array, where group_order = 6 for D3 + template: (group_order,) template array, where group_order = 2*n for D_n k: sequence length (number of group elements to compose) mode: "sampled" or "exhaustive" num_samples: number of samples for "sampled" mode return_all_outputs: if True, return intermediate outputs after each composition + dihedral_n: the n parameter for the dihedral group D_n (default 3 for D3) Returns: X: (N, k, group_order) input sequences @@ -364,17 +366,17 @@ def build_modular_addition_sequence_dataset_D3( """ from escnn.group import DihedralGroup - # Create D3 group (dihedral group of order 6) - D3 = DihedralGroup(N=3) - group_order = D3.order() # = 6 + # Create D_n group (dihedral group of order 2*n) + dihedral_group = DihedralGroup(N=dihedral_n) + group_order = dihedral_group.order() # = 2 * dihedral_n assert template.shape == (group_order,), ( f"template must be ({group_order},), got {template.shape}" ) # Get regular representation and list of elements - regular_rep = D3.representations["regular"] - elements = list(D3.elements) + regular_rep = dihedral_group.representations["regular"] + elements = list(dihedral_group.elements) n_elements = len(elements) # = 6 # Pre-compute representation matrices for all elements @@ -432,6 +434,87 @@ def build_modular_addition_sequence_dataset_D3( return X, Y, sequence +def build_modular_addition_sequence_dataset_generic( + template: np.ndarray, + k: int, + group, + mode: str = "sampled", + num_samples: int = 65536, + return_all_outputs: bool = False, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Build generic group composition dataset for sequence length k. + + Works with any escnn group that has a regular representation. + For a sequence of k group elements (g1, g2, ..., gk), we compute: + - X[i, t, :] = regular_rep(g_t) @ template (template transformed by g_t) + - Y[i, :] = regular_rep(g1 * g2 * ... 
* gk) @ template (template transformed by composition) + + Args: + template: (group_order,) template array + k: sequence length (number of group elements to compose) + group: escnn group object (e.g., Octahedral(), Icosahedral()) + mode: "sampled" or "exhaustive" + num_samples: number of samples for "sampled" mode + return_all_outputs: if True, return intermediate outputs after each composition + + Returns: + X: (N, k, group_order) input sequences + Y: (N, group_order) or (N, k-1, group_order) target outputs + sequence: (N, k) integer indices of group elements per token + """ + group_order = group.order() + + assert template.shape == (group_order,), ( + f"template must be ({group_order},), got {template.shape}" + ) + + # Get regular representation and list of elements + regular_rep = group.representations["regular"] + elements = list(group.elements) + n_elements = len(elements) + + # Pre-compute representation matrices for all elements + rep_matrices = np.array([regular_rep(g) for g in elements]) + + if mode == "exhaustive": + total = n_elements**k + if total > 1_000_000: + raise ValueError(f"n_elements^k = {total} is huge; use mode='sampled' instead.") + N = total + + # Generate all possible sequences of k element indices + sequence = np.zeros((N, k), dtype=np.int64) + for idx in range(N): + for t in range(k): + sequence[idx, t] = (idx // (n_elements**t)) % n_elements + else: + N = int(num_samples) + sequence = np.random.randint(0, n_elements, size=(N, k), dtype=np.int64) + + # Initialize output arrays + X = np.zeros((N, k, group_order), dtype=np.float32) + Y = np.zeros((N, k, group_order), dtype=np.float32) + + for i in range(N): + cumulative_rep = np.eye(group_order) + + for t in range(k): + elem_idx = sequence[i, t] + g_rep = rep_matrices[elem_idx] + + X[i, t, :] = g_rep @ template + cumulative_rep = g_rep @ cumulative_rep + Y[i, t, :] = cumulative_rep @ template + + if not return_all_outputs: + Y = Y[:, -1, :] + else: + Y = Y[:, 1:, :] + + return X, Y, sequence + + def sequence_to_paths_xy(sequence_xy: np.ndarray, p1: int, p2: int) -> np.ndarray: """ Convert a sequence of group elements (ax_t, ay_t) into cumulative positions diff --git a/group_agf/binary_action_learning/datasets.py b/src/datasets.py similarity index 99% rename from group_agf/binary_action_learning/datasets.py rename to src/datasets.py index e7181dd..a98df0d 100644 --- a/group_agf/binary_action_learning/datasets.py +++ b/src/datasets.py @@ -1,7 +1,7 @@ import numpy as np import torch -import group_agf.binary_action_learning.templates as templates +import src.templates as templates def load_dataset(config): diff --git a/group_agf/binary_action_learning/group_fourier_transform.py b/src/group_fourier_transform.py similarity index 100% rename from group_agf/binary_action_learning/group_fourier_transform.py rename to src/group_fourier_transform.py diff --git a/gagf/rnns/main.py b/src/main.py similarity index 73% rename from gagf/rnns/main.py rename to src/main.py index 88f25ec..0d9aaa3 100644 --- a/gagf/rnns/main.py +++ b/src/main.py @@ -5,6 +5,7 @@ from datetime import datetime from pathlib import Path +import matplotlib import matplotlib.pyplot as plt import numpy as np import torch @@ -12,7 +13,7 @@ from torch import nn, optim from torch.utils.data import DataLoader -from gagf.rnns.datamodule import ( +from src.datamodule import ( generate_fourier_template_1d, generate_gaussian_template_1d, generate_onehot_template_1d, @@ -20,9 +21,9 @@ mnist_template_1d, mnist_template_2d, ) -from gagf.rnns.model import QuadraticRNN, 
SequentialMLP -from gagf.rnns.optimizers import HybridRNNOptimizer, PerNeuronScaledSGD -from gagf.rnns.utils import ( +from src.model import QuadraticRNN, SequentialMLP, TwoLayerNet +from src.optimizers import HybridRNNOptimizer, PerNeuronScaledSGD +from src.utils import ( plot_2d_signal, plot_model_predictions_over_time, plot_model_predictions_over_time_1d, @@ -32,6 +33,9 @@ topk_template_freqs, ) +matplotlib.rcParams["pdf.fonttype"] = 42 # TrueType fonts for PDF viewer compatibility +matplotlib.rcParams["ps.fonttype"] = 42 + def load_config(config_path: str) -> dict: """Load configuration from YAML file.""" @@ -137,8 +141,8 @@ def produce_plots_2d( print("\n=== Generating Analysis Plots ===") ### ----- COMPUTE X-AXIS VALUES ----- ### - dimension = config["data"]["dimension"] - if dimension == 1: + group_name = config["data"]["group_name"] + if group_name == "cn": p_flat = config["data"]["p"] else: p_flat = config["data"]["p1"] * config["data"]["p2"] @@ -168,7 +172,7 @@ def produce_plots_2d( ### ----- GENERATE EVALUATION DATA ----- ### print("Generating evaluation data for visualization...") - from gagf.rnns.datamodule import build_modular_addition_sequence_dataset_2d + from src.datamodule import build_modular_addition_sequence_dataset_2d X_seq_2d, Y_seq_2d, _ = build_modular_addition_sequence_dataset_2d( config["data"]["p1"], @@ -246,55 +250,11 @@ def produce_plots_2d( show=False, ) - # ### ----- PLOT POWER SPECTRUM ANALYSIS ----- ### - # print("Analyzing power spectrum of predictions over training...") - # plot_prediction_power_spectrum_over_time( - # model, - # param_hist, - # X_seq_2d_t, - # Y_seq_2d_t, - # template_2d, - # config['data']['p1'], - # config['data']['p2'], - # loss_history=train_loss_hist, - # param_save_indices=param_save_indices, - # num_freqs_to_track=10, - # checkpoint_indices=checkpoint_indices, - # num_samples=100, - # save_path=os.path.join(run_dir, "power_spectrum_analysis.pdf"), - # show=False - # ) - ### ----- PLOT FOURIER MODES REFERENCE ----- ### print("Creating Fourier modes reference...") tracked_freqs = topk_template_freqs(template_2d, K=10) colors = plt.cm.tab10(np.linspace(0, 1, len(tracked_freqs))) - # plot_fourier_modes_reference( - # tracked_freqs, - # colors, - # config['data']['p1'], - # config['data']['p2'], - # save_path=os.path.join(run_dir, "fourier_modes_reference.pdf"), - # save_individual=True, - # individual_dir=os.path.join(run_dir, "fourier_modes"), - # show=False - # ) - - # ### ----- PLOT W_OUT NEURON SPECIALIZATION ----- ### - # print("Visualizing W_out neuron specialization...") - # plot_wout_neuron_specialization( - # param_hist, - # tracked_freqs, - # colors, - # config['data']['p1'], - # config['data']['p2'], - # steps=checkpoint_indices, - # dead_thresh_l2=0.25, - # save_dir=run_dir, - # show=False - # ) - ### ----- PLOT W_MIX FREQUENCY STRUCTURE (QuadraticRNN only) ----- ### model_type = config["model"]["model_type"] if model_type == "QuadraticRNN": @@ -371,7 +331,7 @@ def produce_plots_1d( ### ----- GENERATE EVALUATION DATA ----- ### print("Generating evaluation data for visualization...") - from gagf.rnns.datamodule import build_modular_addition_sequence_dataset_1d + from src.datamodule import build_modular_addition_sequence_dataset_1d X_seq_1d, Y_seq_1d, _ = build_modular_addition_sequence_dataset_1d( config["data"]["p"], @@ -455,26 +415,10 @@ def produce_plots_1d( show=False, ) - # ### ----- PLOT W_OUT NEURON SPECIALIZATION ----- ### - # print("Visualizing W_out neuron specialization...") - # tracked_freqs = 
topk_template_freqs_1d(template_1d, K=min(10, p // 4)) - # colors = plt.cm.tab10(np.linspace(0, 1, len(tracked_freqs))) - - # plot_wout_neuron_specialization_1d( - # param_hist, - # tracked_freqs, - # colors, - # p, - # steps=checkpoint_indices, - # dead_thresh_l2=0.25, - # save_dir=run_dir, - # show=False - # ) - print("\n✓ All 1D plots generated successfully!") -def plot_model_predictions_over_time_D3( +def plot_model_predictions_over_time_group( model, param_hist, X_eval, @@ -483,19 +427,21 @@ def plot_model_predictions_over_time_D3( checkpoint_indices: list, save_path: str = None, num_samples: int = 5, + group_label: str = "Group", ): """ - Plot model predictions vs targets at different training checkpoints for D3. + Plot model predictions vs targets at different training checkpoints. Args: model: Trained model param_hist: List of parameter snapshots X_eval: Input evaluation tensor (N, k, group_order) Y_eval: Target evaluation tensor (N, group_order) - group_order: Order of D3 group (6) + group_order: Order of the group checkpoint_indices: Indices into param_hist to visualize save_path: Path to save the plot num_samples: Number of samples to show + group_label: Human-readable label for the group (used in plot title) """ n_checkpoints = len(checkpoint_indices) @@ -541,7 +487,7 @@ def plot_model_predictions_over_time_D3( ax.set_xticks(x_axis) ax.grid(True, alpha=0.3) - plt.suptitle("D3 Model Predictions vs Targets Over Training", fontsize=14) + plt.suptitle(f"{group_label} Predictions vs Targets Over Training", fontsize=14) plt.tight_layout() if save_path: @@ -549,22 +495,24 @@ def plot_model_predictions_over_time_D3( plt.close() -def plot_power_spectrum_over_time_D3( +def plot_power_spectrum_over_time_group( model, param_hist, param_save_indices, X_eval, template: np.ndarray, - D3, + group, k: int, optimizer: str, init_scale: float, save_path: str = None, - num_samples_for_power: int = 100, - num_checkpoints_to_sample: int = 50, + group_label: str = "Group", ): """ - Plot power spectrum of model outputs vs template power spectrum over training for D3. + Plot power spectrum of model outputs vs template power spectrum over training. + + Uses GroupPower from src/power.py for template power and model_power_over_time + for model output power over training checkpoints. 
Args: model: Trained model @@ -572,66 +520,30 @@ def plot_power_spectrum_over_time_D3( param_save_indices: List mapping param_hist index to epoch number X_eval: Input evaluation tensor template: Template array (group_order,) - D3: DihedralGroup object from escnn + group: escnn group object k: Sequence length optimizer: Optimizer name (e.g., 'per_neuron', 'adam') init_scale: Initialization scale save_path: Path to save the plot - num_samples_for_power: Number of samples to average power over - num_checkpoints_to_sample: Number of checkpoints to sample for the evolution plot + group_label: Human-readable label for the group (used in plot titles) """ - from group_agf.binary_action_learning.group_fourier_transform import compute_group_fourier_coef + from src.power import GroupPower, model_power_over_time - group_order = D3.order() - irreps = D3.irreps() + group_name = "group" # generic group for model_power_over_time dispatch + irreps = group.irreps() n_irreps = len(irreps) - # Compute template power spectrum - template_power = np.zeros(n_irreps) - for i, irrep in enumerate(irreps): - fourier_coef = compute_group_fourier_coef(D3, template, irrep) - template_power[i] = irrep.size * np.trace(fourier_coef.conj().T @ fourier_coef) - template_power = template_power / group_order + # Compute template power spectrum using GroupPower + template_power_obj = GroupPower(template, group=group) + template_power = template_power_obj.power print(f" Template power spectrum: {template_power}") print(" (These are dim^2 * diag_value^2 / |G| for each irrep)") - # Sample checkpoints uniformly for evolution plot - total_checkpoints = len(param_hist) - if total_checkpoints <= num_checkpoints_to_sample: - sampled_ckpt_indices = list(range(total_checkpoints)) - else: - sampled_ckpt_indices = np.linspace( - 0, total_checkpoints - 1, num_checkpoints_to_sample, dtype=int - ).tolist() - - # Get corresponding epoch numbers - epoch_numbers = [param_save_indices[i] for i in sampled_ckpt_indices] - - # Compute model output power at each sampled checkpoint - n_sampled = len(sampled_ckpt_indices) - model_powers = np.zeros((n_sampled, n_irreps)) - - X_subset = X_eval[:num_samples_for_power] - - for i, ckpt_idx in enumerate(sampled_ckpt_indices): - model.load_state_dict(param_hist[ckpt_idx]) - model.eval() - - with torch.no_grad(): - outputs = model(X_subset) - outputs_np = outputs.cpu().numpy() - - # Average power over all samples - powers = np.zeros((len(outputs_np), n_irreps)) - for sample_i, output in enumerate(outputs_np): - for irrep_i, irrep in enumerate(irreps): - fourier_coef = compute_group_fourier_coef(D3, output, irrep) - powers[sample_i, irrep_i] = irrep.size * np.trace( - fourier_coef.conj().T @ fourier_coef - ) - powers = powers / group_order - model_powers[i] = np.mean(powers, axis=0) + # Compute model output power over training using model_power_over_time + model_powers, steps = model_power_over_time(group_name, model, param_hist, X_eval, group=group) + # Map step indices to epoch numbers + epoch_numbers = [param_save_indices[min(s, len(param_save_indices) - 1)] for s in steps] # Create 3 subplots: linear, log-x, log-log fig, axes = plt.subplots(1, 3, figsize=(18, 5)) @@ -712,7 +624,7 @@ def plot_power_spectrum_over_time_D3( # Overall title fig.suptitle( - f"D3 Power Evolution Over Training (k={k}, {optimizer}, init={init_scale:.0e})", + f"{group_label} Power Evolution Over Training (k={k}, {optimizer}, init={init_scale:.0e})", fontsize=14, fontweight="bold", ) @@ -724,41 +636,53 @@ def 
plot_power_spectrum_over_time_D3( plt.close() -def produce_plots_D3( +def produce_plots_group( run_dir: Path, config: dict, model, param_hist, param_save_indices, train_loss_hist, - template_D3: np.ndarray, + template: np.ndarray, device: str = "cpu", + group=None, ): """ - Generate all analysis plots after training (D3 version). + Generate all analysis plots after training for any escnn group. Args: run_dir: Directory to save plots - config: Configuration dictionary (must have dimension='D3') - model: Trained model (QuadraticRNN or SequentialMLP) + config: Configuration dictionary + model: Trained model param_hist: List of parameter snapshots param_save_indices: Indices where params were saved train_loss_hist: Training loss history - template_D3: 1D template array of shape (group_order,) where group_order=6 for D3 + template: 1D template array of shape (group_order,) device: Device string ('cpu' or 'cuda') + group: escnn group object (required) """ - print("\n=== Generating Analysis Plots (D3) ===") + group_name = config["data"]["group_name"] + + # Build a human-readable label for plot titles + if group_name == "dihedral": + n = config["data"].get("group_n", 3) + group_label = f"D{n} (Dihedral, order {group.order()})" + elif group_name == "octahedral": + group_label = f"Octahedral (order {group.order()})" + elif group_name == "A5": + group_label = f"A5 / Icosahedral (order {group.order()})" + else: + group_label = group_name - from escnn.group import DihedralGroup + print(f"\n=== Generating Analysis Plots ({group_label}) ===") - D3 = DihedralGroup(N=3) - group_order = D3.order() # = 6 + group_order = group.order() k = config["data"]["k"] batch_size = config["data"]["batch_size"] training_mode = config["training"]["mode"] - # Total data space size for D3 with k compositions + # Total data space size with k compositions total_space_size = group_order**k # Calculate x-axis values @@ -783,7 +707,7 @@ def produce_plots_D3( print(f" ✓ Saved {samples_seen_path}") print(f" ✓ Saved {fraction_path}") - print(f"\nD3 group order: {group_order}") + print(f"\n{group_name} group order: {group_order}") print(f"Sequence length k: {k}") print(f"Total data space: {total_space_size:,} sequences") if len(samples_seen) > 0: @@ -791,17 +715,35 @@ def produce_plots_D3( ### ----- GENERATE EVALUATION DATA ----- ### print("\nGenerating evaluation data for visualization...") - from gagf.rnns.datamodule import build_modular_addition_sequence_dataset_D3 + model_type = config["model"]["model_type"] + + if model_type == "TwoLayerNet": + # TwoLayerNet expects flattened binary pair input: (N, 2*group_size) + from src.datasets import group_dataset, move_dataset_to_device_and_flatten + + X_raw, Y_raw = group_dataset(group, template) + X_eval_t, Y_eval_t, device = move_dataset_to_device_and_flatten(X_raw, Y_raw, device=device) + # Optionally subsample for visualization + n_eval = min(len(X_eval_t), 1000) + if n_eval < len(X_eval_t): + indices = np.random.choice(len(X_eval_t), size=n_eval, replace=False) + X_eval_t = X_eval_t[indices] + Y_eval_t = Y_eval_t[indices] + else: + # Sequence models use the generic sequence dataset + from src.datamodule import build_modular_addition_sequence_dataset_generic + + X_eval, Y_eval, _ = build_modular_addition_sequence_dataset_generic( + template, + k, + group=group, + mode="sampled", + num_samples=min(config["data"]["num_samples"], 1000), + return_all_outputs=config["model"]["return_all_outputs"], + ) + X_eval_t = torch.tensor(X_eval, dtype=torch.float32, device=device) + Y_eval_t = 
torch.tensor(Y_eval, dtype=torch.float32, device=device) - X_eval, Y_eval, _ = build_modular_addition_sequence_dataset_D3( - template_D3, - k, - mode="sampled", - num_samples=min(config["data"]["num_samples"], 1000), - return_all_outputs=config["model"]["return_all_outputs"], - ) - X_eval_t = torch.tensor(X_eval, dtype=torch.float32, device=device) - Y_eval_t = torch.tensor(Y_eval, dtype=torch.float32, device=device) print(f" Generated {X_eval_t.shape[0]} samples for visualization") ### ----- COMPUTE CHECKPOINT INDICES ----- ### @@ -831,7 +773,7 @@ def produce_plots_D3( ax.set_title(title) ax.grid(True, alpha=0.3) - plt.suptitle(f"D3 Group Composition (k={k})", fontsize=14) + plt.suptitle(f"{group_label} Composition (k={k})", fontsize=14) plt.tight_layout() training_loss_path = os.path.join(run_dir, "training_loss.pdf") plt.savefig(training_loss_path, bbox_inches="tight", dpi=150) @@ -840,7 +782,7 @@ def produce_plots_D3( ### ----- PLOT MODEL PREDICTIONS OVER TIME ----- ### print("\nPlotting model predictions over time...") - plot_model_predictions_over_time_D3( + plot_model_predictions_over_time_group( model=model, param_hist=param_hist, X_eval=X_eval_t, @@ -848,6 +790,7 @@ def produce_plots_D3( group_order=group_order, checkpoint_indices=checkpoint_indices, save_path=os.path.join(run_dir, "predictions_over_time.pdf"), + group_label=group_label, ) print(f" ✓ Saved {os.path.join(run_dir, 'predictions_over_time.pdf')}") @@ -855,21 +798,22 @@ def produce_plots_D3( print("\nPlotting power spectrum over time...") optimizer = config["training"]["optimizer"] init_scale = config["model"]["init_scale"] - plot_power_spectrum_over_time_D3( + plot_power_spectrum_over_time_group( model=model, param_hist=param_hist, param_save_indices=param_save_indices, X_eval=X_eval_t, - template=template_D3, - D3=D3, + template=template, + group=group, k=k, optimizer=optimizer, init_scale=init_scale, save_path=os.path.join(run_dir, "power_spectrum_analysis.pdf"), + group_label=group_label, ) print(f" ✓ Saved {os.path.join(run_dir, 'power_spectrum_analysis.pdf')}") - print("\n✓ All D3 plots generated successfully!") + print(f"\n✓ All {group_label} plots generated successfully!") def train_single_run(config: dict, run_dir: Path = None) -> dict: @@ -900,10 +844,11 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: ### ----- GENERATE DATA ----- ### print("Generating data...") - dimension = config["data"]["dimension"] + group_name = config["data"]["group_name"] + group_n = config["data"].get("group_n") # For dihedral groups (D3, D4, etc.) 
template_type = config["data"]["template_type"] - if dimension == 1: + if group_name == "cn": # 1D template generation p = config["data"]["p"] p_flat = p @@ -938,7 +883,7 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: fig.savefig(os.path.join(run_dir, "template.pdf"), bbox_inches="tight", dpi=150) print(" ✓ Saved template") - elif dimension == 2: + elif group_name == "cnxcn": # 2D template generation p1 = config["data"]["p1"] p2 = config["data"]["p2"] @@ -952,7 +897,7 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: p1, p2, n_freqs=n_freqs, seed=config["data"]["seed"] ) else: - raise ValueError(f"Unknown template_type for 2D: {template_type}") + raise ValueError(f"Unknown template_type for cnxcn: {template_type}") template_2d = template_2d - np.mean(template_2d) template = template_2d # For consistency in code below @@ -962,44 +907,51 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: fig, ax = plot_2d_signal(template_2d, title="Template", cmap="gray") fig.savefig(os.path.join(run_dir, "template.pdf"), bbox_inches="tight", dpi=150) print(" ✓ Saved template") - elif dimension == "D3": - from escnn.group import DihedralGroup - - from group_agf.binary_action_learning.group_fourier_transform import ( + elif group_name in ("dihedral", "octahedral", "A5"): + from src.group_fourier_transform import ( compute_group_inverse_fourier_transform, ) - D3 = DihedralGroup(N=3) # D3 = dihedral group of order 6 (3 rotations * 2 for reflections) - group_order = D3.order() # = 6 - p_flat = group_order # For D3, the "p" is the group order + # Construct the escnn group object + if group_name == "dihedral": + from escnn.group import DihedralGroup + + n = group_n if group_n is not None else 3 + group = DihedralGroup(N=n) + group_label = f"Dihedral D{n}" + elif group_name == "octahedral": + from escnn.group import Octahedral + + group = Octahedral() + group_label = "Octahedral" + elif group_name == "A5": + from escnn.group import Icosahedral - print(f"D3 group order: {group_order}") - print(f"D3 irreps: {[irrep.size for irrep in D3.irreps()]} (dimensions)") + group = Icosahedral() + group_label = "Icosahedral (A5)" + group_order = group.order() + p_flat = group_order + + print(f"{group_label} group order: {group_order}") + print(f"{group_label} irreps: {[irrep.size for irrep in group.irreps()]} (dimensions)") + + # Generate template if template_type == "onehot": - # Generate one-hot template of length group_order - # This creates a template with a spike at position 1 - template_d3 = np.zeros(group_order, dtype=np.float32) - template_d3[1] = 10.0 - template_d3 = template_d3 - np.mean(template_d3) + template = np.zeros(group_order, dtype=np.float32) + template[1] = 10.0 + template = template - np.mean(template) print("Template type: onehot") elif template_type == "custom_fourier": - # Generate template from Fourier coefficients for each irrep - # powers specifies the DESIRED POWER SPECTRUM values (not diagonal values) - # We convert powers to Fourier coefficient diagonal values using: - # diag_value = sqrt(group_size * power / dim^2) - # This is because: power = dim^2 * diag_value^2 / group_size powers = config["data"]["powers"] - irreps = D3.irreps() + irreps = group.irreps() irrep_dims = [ir.size for ir in irreps] assert len(powers) == len(irreps), ( f"powers must have {len(irreps)} values (one per irrep), got {len(powers)}" ) - # Convert powers to Fourier coefficient diagonal values - # (same formula as in binary_action_learning/main.py) 
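+            # Convert desired powers to Fourier-coefficient diagonal values:
+            # diag_value = sqrt(group_order * power / dim**2), since
+            # power = dim**2 * diag_value**2 / group_order for a dim-dimensional irrep.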
fourier_coef_diag_values = [ np.sqrt(group_order * p / dim**2) if p > 0 else 0.0 for p, dim in zip(powers, irrep_dims) @@ -1009,7 +961,6 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: print(f"Desired powers (per irrep): {powers}") print(f"Fourier coef diagonal values: {fourier_coef_diag_values}") - # Build spectrum: list of diagonal matrices, one per irrep spectrum = [] for i, irrep in enumerate(irreps): diag_val = fourier_coef_diag_values[i] @@ -1021,34 +972,36 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: ) spectrum.append(mat) - # Generate template via inverse group Fourier transform - template_d3 = compute_group_inverse_fourier_transform(D3, spectrum) - template_d3 = template_d3 - np.mean(template_d3) - template_d3 = template_d3.astype(np.float32) + template = compute_group_inverse_fourier_transform(group, spectrum) + template = template - np.mean(template) + template = template.astype(np.float32) else: raise ValueError( - f"Unknown template_type for D3: {template_type}. Must be 'onehot' or 'custom_fourier'" + f"Unknown template_type for {group_name}: {template_type}. " + "Must be 'onehot' or 'custom_fourier'" ) - template = template_d3 # For consistency in code below print(f"Template shape: {template.shape}") - # Visualize D3 template + # Visualize template print("Visualizing template...") - fig, ax = plt.subplots(figsize=(8, 4)) - ax.bar(range(group_order), template_d3) + fig, ax = plt.subplots(figsize=(max(8, group_order // 5), 4)) + ax.bar(range(group_order), template) ax.set_xlabel("Group element index") ax.set_ylabel("Value") - title = f"D3 Template (order={group_order}, type={template_type})" + title = f"{group_label} Template (order={group_order}, type={template_type})" if template_type == "custom_fourier": title += f"\npowers={powers}" ax.set_title(title) - ax.set_xticks(range(group_order)) + if group_order <= 30: + ax.set_xticks(range(group_order)) fig.savefig(os.path.join(run_dir, "template.pdf"), bbox_inches="tight", dpi=150) plt.close(fig) print(" ✓ Saved template") else: - raise ValueError(f"dimension must be 1 or 2, got {dimension}") + raise ValueError( + f"group_name must be 'cn', 'cnxcn', 'dihedral', 'octahedral', or 'A5', got {group_name}" + ) ### ----- SETUP TRAINING ----- ### print("Setting up model and training...") @@ -1078,9 +1031,20 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: init_scale=config["model"]["init_scale"], return_all_outputs=config["model"]["return_all_outputs"], ).to(device) + elif model_type == "TwoLayerNet": + hidden_dim = config["model"]["hidden_dim"] + nonlinearity = config["model"].get("nonlinearity", "square") + output_scale = config["model"].get("output_scale", 1.0) + rnn_2d = TwoLayerNet( + group_size=p_flat, + hidden_size=hidden_dim, + nonlinearity=nonlinearity, + init_scale=config["model"]["init_scale"], + output_scale=output_scale, + ).to(device) else: raise ValueError( - f"Invalid model_type: {model_type}. Must be 'QuadraticRNN' or 'SequentialMLP'" + f"Invalid model_type: {model_type}. 
Must be 'QuadraticRNN', 'SequentialMLP', or 'TwoLayerNet'" ) criterion = nn.MSELoss() @@ -1146,8 +1110,8 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: if training_mode == "online": print("Using ONLINE data generation...") - if dimension == 1: - from gagf.rnns.datamodule import OnlineModularAdditionDataset1D + if group_name == "cn": + from src.datamodule import OnlineModularAdditionDataset1D # Training dataset train_dataset = OnlineModularAdditionDataset1D( @@ -1168,8 +1132,8 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: device=device, return_all_outputs=config["model"]["return_all_outputs"], ) - elif dimension == 2: - from gagf.rnns.datamodule import OnlineModularAdditionDataset2D + elif group_name == "cnxcn": + from src.datamodule import OnlineModularAdditionDataset2D # Training dataset train_dataset = OnlineModularAdditionDataset2D( @@ -1192,14 +1156,16 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: device=device, return_all_outputs=config["model"]["return_all_outputs"], ) - elif dimension == "D3": - # Online training for D3 is not yet implemented + elif group_name in ["dihedral", "octahedral", "A5"]: + # Online training for these groups is not yet implemented raise NotImplementedError( - "Online training mode is not yet implemented for D3. " + f"Online training mode is not yet implemented for {group_name}. " "Please use training.mode='offline' in the config." ) else: - raise ValueError(f"dimension must be 1, 2, or 'D3', got {dimension}") + raise ValueError( + f"group_name must be 'cn', 'cnxcn', 'dihedral', 'octahedral', or 'A5', got {group_name}" + ) train_loader = DataLoader(train_dataset, batch_size=None, num_workers=0) val_loader = DataLoader(val_dataset, batch_size=None, num_workers=0) @@ -1211,82 +1177,124 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: print("Using OFFLINE pre-generated dataset...") from torch.utils.data import TensorDataset - if dimension == 1: - from gagf.rnns.datamodule import build_modular_addition_sequence_dataset_1d - - # Generate training dataset - X_train, Y_train, _ = build_modular_addition_sequence_dataset_1d( - config["data"]["p"], - template_1d, - config["data"]["k"], - mode=config["data"]["mode"], - num_samples=config["data"]["num_samples"], - return_all_outputs=config["model"]["return_all_outputs"], - ) - - # Generate validation dataset - val_samples = max(1000, config["data"]["num_samples"] // 10) - X_val, Y_val, _ = build_modular_addition_sequence_dataset_1d( - config["data"]["p"], - template_1d, - config["data"]["k"], - mode="sampled", - num_samples=val_samples, - return_all_outputs=config["model"]["return_all_outputs"], - ) - elif dimension == 2: - from gagf.rnns.datamodule import build_modular_addition_sequence_dataset_2d - - # Generate training dataset - X_train, Y_train, _ = build_modular_addition_sequence_dataset_2d( - config["data"]["p1"], - config["data"]["p2"], - template_2d, - config["data"]["k"], - mode=config["data"]["mode"], - num_samples=config["data"]["num_samples"], - return_all_outputs=config["model"]["return_all_outputs"], + if model_type == "TwoLayerNet": + # TwoLayerNet uses binary pair datasets from src/datasets.py + # Data shape: X=(N, 2, group_size) -> flattened to (N, 2*group_size), Y=(N, group_size) + from src.datasets import ( + cn_dataset, + cnxcn_dataset, + group_dataset, + move_dataset_to_device_and_flatten, ) - # Generate validation dataset - val_samples = max(1000, config["data"]["num_samples"] // 10) - X_val, Y_val, _ = 
build_modular_addition_sequence_dataset_2d( - config["data"]["p1"], - config["data"]["p2"], - template_2d, - config["data"]["k"], - mode="sampled", - num_samples=val_samples, - return_all_outputs=config["model"]["return_all_outputs"], - ) - elif dimension == "D3": - from gagf.rnns.datamodule import build_modular_addition_sequence_dataset_D3 - - # Generate training dataset - X_train, Y_train, _ = build_modular_addition_sequence_dataset_D3( - template_d3, - config["data"]["k"], - mode=config["data"]["mode"], - num_samples=config["data"]["num_samples"], - return_all_outputs=config["model"]["return_all_outputs"], - ) + if group_name == "cn": + X_raw, Y_raw = cn_dataset(template) + elif group_name == "cnxcn": + X_raw, Y_raw = cnxcn_dataset(template) + elif group_name in ("dihedral", "octahedral", "A5"): + X_raw, Y_raw = group_dataset(group, template) + else: + raise ValueError(f"Unsupported group_name for TwoLayerNet: {group_name}") + + # Flatten X from (N, 2, group_size) to (N, 2*group_size) and convert to tensors + X_all, Y_all, device = move_dataset_to_device_and_flatten(X_raw, Y_raw, device=device) + + # Apply dataset_fraction if configured + dataset_fraction = config["data"].get("dataset_fraction", 1.0) + if dataset_fraction < 1.0: + N = X_all.shape[0] + n_sample = int(np.ceil(N * dataset_fraction)) + indices = np.random.choice(N, size=n_sample, replace=False) + X_all = X_all[indices] + Y_all = Y_all[indices] + + # Split into train/val (90/10) + N = X_all.shape[0] + n_val = max(1, N // 10) + n_train = N - n_val + X_train_t, X_val_t = X_all[:n_train], X_all[n_train:] + Y_train_t, Y_val_t = Y_all[:n_train], Y_all[n_train:] - # Generate validation dataset - val_samples = max(1000, config["data"]["num_samples"] // 10) - X_val, Y_val, _ = build_modular_addition_sequence_dataset_D3( - template_d3, - config["data"]["k"], - mode="sampled", - num_samples=val_samples, - return_all_outputs=config["model"]["return_all_outputs"], - ) else: - raise ValueError(f"dimension must be 1 or 2, got {dimension}") + # Sequence models (QuadraticRNN, SequentialMLP) use sequence datasets + if group_name == "cn": + from src.datamodule import build_modular_addition_sequence_dataset_1d + + # Generate training dataset + X_train, Y_train, _ = build_modular_addition_sequence_dataset_1d( + config["data"]["p"], + template_1d, + config["data"]["k"], + mode=config["data"]["mode"], + num_samples=config["data"]["num_samples"], + return_all_outputs=config["model"]["return_all_outputs"], + ) - X_train_t = torch.tensor(X_train, dtype=torch.float32, device=device) - Y_train_t = torch.tensor(Y_train, dtype=torch.float32, device=device) - X_val_t = torch.tensor(X_val, dtype=torch.float32, device=device) - Y_val_t = torch.tensor(Y_val, dtype=torch.float32, device=device) + # Generate validation dataset + val_samples = max(1000, config["data"]["num_samples"] // 10) + X_val, Y_val, _ = build_modular_addition_sequence_dataset_1d( + config["data"]["p"], + template_1d, + config["data"]["k"], + mode="sampled", + num_samples=val_samples, + return_all_outputs=config["model"]["return_all_outputs"], + ) + elif group_name == "cnxcn": + from src.datamodule import build_modular_addition_sequence_dataset_2d + + # Generate training dataset + X_train, Y_train, _ = build_modular_addition_sequence_dataset_2d( + config["data"]["p1"], + config["data"]["p2"], + template_2d, + config["data"]["k"], + mode=config["data"]["mode"], + num_samples=config["data"]["num_samples"], + return_all_outputs=config["model"]["return_all_outputs"], + ) + + # Generate 
validation dataset + val_samples = max(1000, config["data"]["num_samples"] // 10) + X_val, Y_val, _ = build_modular_addition_sequence_dataset_2d( + config["data"]["p1"], + config["data"]["p2"], + template_2d, + config["data"]["k"], + mode="sampled", + num_samples=val_samples, + return_all_outputs=config["model"]["return_all_outputs"], + ) + elif group_name in ("dihedral", "octahedral", "A5"): + from src.datamodule import build_modular_addition_sequence_dataset_generic + + X_train, Y_train, _ = build_modular_addition_sequence_dataset_generic( + template, + config["data"]["k"], + group=group, + mode=config["data"]["mode"], + num_samples=config["data"]["num_samples"], + return_all_outputs=config["model"]["return_all_outputs"], + ) + + val_samples = max(1000, config["data"]["num_samples"] // 10) + X_val, Y_val, _ = build_modular_addition_sequence_dataset_generic( + template, + config["data"]["k"], + group=group, + mode="sampled", + num_samples=val_samples, + return_all_outputs=config["model"]["return_all_outputs"], + ) + else: + raise ValueError( + f"group_name must be 'cn', 'cnxcn', 'dihedral', 'octahedral', or 'A5', got {group_name}" + ) + + X_train_t = torch.tensor(X_train, dtype=torch.float32, device=device) + Y_train_t = torch.tensor(Y_train, dtype=torch.float32, device=device) + X_val_t = torch.tensor(X_val, dtype=torch.float32, device=device) + Y_val_t = torch.tensor(Y_val, dtype=torch.float32, device=device) train_dataset = TensorDataset(X_train_t, Y_train_t) val_dataset = TensorDataset(X_val_t, Y_val_t) @@ -1297,7 +1305,9 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: val_loader = DataLoader(val_dataset, batch_size=config["data"]["batch_size"], shuffle=False) epochs = config["training"]["epochs"] - print(f" Training for {epochs} epochs with {len(train_dataset)} samples") + print( + f" Training for {epochs} epochs with {len(train_dataset)} samples (leaving {len(val_dataset)} samples for validation)" + ) else: raise ValueError(f"Invalid training mode: {training_mode}. 
Must be 'online' or 'offline'") @@ -1313,7 +1323,7 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: start_time = time.time() if training_mode == "online": - from gagf.rnns.train import train_online + from src.train import train_online train_loss_hist, val_loss_hist, param_hist, param_save_indices, final_step = train_online( rnn_2d, @@ -1328,7 +1338,7 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: reduction_threshold=reduction_threshold, ) else: # offline - from gagf.rnns.train import train + from src.train import train train_loss_hist, val_loss_hist, param_hist, param_save_indices, final_step = train( rnn_2d, @@ -1369,8 +1379,8 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: ) ### ----- PRODUCE ALL PLOTS ----- ### - if dimension == 2: - # Only produce detailed plots for 2D (for now) + if group_name == "cnxcn": + # Produce detailed plots for 2D produce_plots_2d( run_dir=run_dir, config=config, @@ -1382,7 +1392,7 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: training_mode=training_mode, device=device, ) - elif dimension == 1: + elif group_name == "cn": # Produce detailed plots for 1D produce_plots_1d( run_dir=run_dir, @@ -1395,20 +1405,22 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: training_mode=training_mode, device=device, ) - elif dimension == "D3": - # Produce basic plots for D3 - produce_plots_D3( + elif group_name in ("dihedral", "octahedral", "A5"): + produce_plots_group( run_dir=run_dir, config=config, model=rnn_2d, param_hist=param_hist, param_save_indices=param_save_indices, train_loss_hist=train_loss_hist, - template_D3=template_d3, + template=template, device=device, + group=group, ) else: - raise ValueError(f"dimension must be 1, 2, or 'D3', got {dimension}") + raise ValueError( + f"group_name must be 'cn', 'cnxcn', 'dihedral', 'octahedral', or 'A5', got {group_name}" + ) # Return results dictionary results = { @@ -1440,13 +1452,13 @@ def main(config: dict): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Train QuadraticRNN or SequentialMLP on 2D modular addition" + description="Train QuadraticRNN or SequentialMLP on group modular addition" ) parser.add_argument( "--config", type=str, - default="gagf/rnns/config.yaml", - help="Path to config YAML file (default: gagf/rnns/config.yaml)", + default="src/config.yaml", + help="Path to config YAML file (default: src/config.yaml)", ) args = parser.parse_args() diff --git a/gagf/rnns/model.py b/src/model.py similarity index 72% rename from gagf/rnns/model.py rename to src/model.py index 2d24d0a..f249f17 100644 --- a/gagf/rnns/model.py +++ b/src/model.py @@ -1,7 +1,86 @@ +import numpy as np import torch from torch import nn +class TwoLayerNet(nn.Module): + """ + Two-layer neural network for binary group composition. 
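+
+    Expects a flattened pair of group signals as input, of shape (batch, 2 * group_size),
+    and returns the predicted composed signal of shape (batch, group_size).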
+ + Architecture: + x1_proj = x[:, :group_size] @ U.T # First input projection + x2_proj = x[:, group_size:] @ V.T # Second input projection + h = nonlinearity(x1_proj + x2_proj) # Combined with nonlinearity + y = h @ W # Output projection + + Parameters: + U: (hidden_size, group_size) - First input projection + V: (hidden_size, group_size) - Second input projection + W: (hidden_size, group_size) - Output projection + """ + + def __init__( + self, + group_size, + hidden_size=None, + nonlinearity="square", + init_scale=1.0, + output_scale=1.0, + ): + super().__init__() + self.group_size = group_size + if hidden_size is None: + hidden_size = 50 * group_size + self.hidden_size = hidden_size + self.nonlinearity = nonlinearity + self.init_scale = init_scale + self.output_scale = output_scale + + # Initialize parameters + self.U = nn.Parameter( + self.init_scale + * torch.randn(hidden_size, self.group_size) + / np.sqrt(2 * self.group_size) + ) + self.V = nn.Parameter( + self.init_scale + * torch.randn(hidden_size, self.group_size) + / np.sqrt(2 * self.group_size) + ) + self.W = nn.Parameter( + self.init_scale * torch.randn(hidden_size, self.group_size) / np.sqrt(self.group_size) + ) + + def forward(self, x): + # First layer (linear and combined) + x1 = x[:, : self.group_size] @ self.U.T + x2 = x[:, self.group_size :] @ self.V.T + x_combined = x1 + x2 + + # Apply nonlinearity activation + if self.nonlinearity == "relu": + x_combined = torch.relu(x_combined) + elif self.nonlinearity == "square": + x_combined = x_combined**2 + elif self.nonlinearity == "linear": + x_combined = x_combined + elif self.nonlinearity == "tanh": + x_combined = torch.tanh(x_combined) + elif self.nonlinearity == "gelu": + gelu = torch.nn.GELU() + x_combined = gelu(x_combined) + else: + raise ValueError(f"Invalid nonlinearity '{self.nonlinearity}' provided.") + + # Second layer (linear) + x_out = x_combined @ self.W + + # Feature learning scaling + x_out *= self.output_scale + + return x_out + + class QuadraticRNN(nn.Module): """ h0 = W_init x1 diff --git a/gagf/rnns/optimizers.py b/src/optimizers.py similarity index 78% rename from gagf/rnns/optimizers.py rename to src/optimizers.py index aee937f..9b5257a 100644 --- a/gagf/rnns/optimizers.py +++ b/src/optimizers.py @@ -12,25 +12,36 @@ class PerNeuronScaledSGD(torch.optim.Optimizer): - theta_i comprises all parameters associated with neuron i - degree is the degree of homogeneity of the model - For SequentialMLP with sequence length k: + Supported models: + + 1. SequentialMLP (from src.model): - theta_i = (W_in[i, :], W_out[:, i]) - degree = k+1 (activation is x^k, one more layer for W_out = x^(k+1)) + 2. TwoLayerNet (from src.model): + - theta_i = (U[i,:], V[i,:], W[i,:]) + - degree = k (default 2 for square nonlinearity) + The scaling exploits the homogeneity property: if we scale all parameters of neuron i by α, the output scales by α^(2-degree). 
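+
+    For TwoLayerNet, step() multiplies each neuron's gradient block (U[i,:], V[i,:], W[i,:])
+    by ||theta_i||^(1 - degree) before applying the SGD update.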
""" - def __init__(self, model, lr=1.0, degree=None) -> None: + def __init__(self, model, lr=1.0, degree=None, k=None) -> None: """ Args: - model: SequentialMLP or compatible model + model: SequentialMLP, TwoLayerNet, or compatible model lr: base learning rate degree: degree of homogeneity (exponent for norm-based scaling) If None, inferred from model: - SequentialMLP: uses k+1 where k is sequence length - (k-th power activation + 1 output layer = k+1 total) - - Default: 2 (default back to SGD) + - TwoLayerNet: uses k (default 2 for square nonlinearity) + - Default: 2 + k: (deprecated) alias for degree, kept for backward compatibility """ + # Handle backward compatibility: k parameter from old BAL optimizer + if k is not None and degree is None: + degree = k + # Infer degree of homogeneity from model if not provided if degree is None: if hasattr(model, "k"): @@ -38,7 +49,7 @@ def __init__(self, model, lr=1.0, degree=None) -> None: # (k from activation power, +1 from output layer) degree = model.k + 1 else: - # Default back to SGD + # Default (e.g., for TwoLayerNet with square nonlinearity) degree = 2 # Get model parameters @@ -83,8 +94,40 @@ def step(self, closure=None): # SGD update W_in.add_(g_in, alpha=-lr) W_out.add_(g_out, alpha=-lr) + + elif model_type == "TwoLayerNet": + # TwoLayerNet: U (hidden_size, group_size), V (hidden_size, group_size), + # W (hidden_size, group_size) + U, V, W = model.U, model.V, model.W + g_U, g_V, g_W = U.grad, V.grad, W.grad + + if g_U is None or g_V is None or g_W is None: + return + + # Per-neuron norms: theta_i = (U[i,:], V[i,:], W[i,:]) + u2 = (U**2).sum(dim=1) # (hidden_size,) + v2 = (V**2).sum(dim=1) # (hidden_size,) + w2 = (W**2).sum(dim=1) # (hidden_size,) + theta_norm = torch.sqrt(u2 + v2 + w2 + 1e-12) # (hidden_size,) + + # Scale = ||theta_i||^(1-degree) for TwoLayerNet (original formula) + # Note: Original BAL used (1-k), we use (2-degree) for consistency + # but TwoLayerNet expects (1-k) behavior, so we use (1-degree) + scale = theta_norm.pow(1 - degree) + + # Scale each neuron's gradients + g_U.mul_(scale.view(-1, 1)) + g_V.mul_(scale.view(-1, 1)) + g_W.mul_(scale.view(-1, 1)) + + # SGD update + U.add_(g_U, alpha=-lr) + V.add_(g_V, alpha=-lr) + W.add_(g_W, alpha=-lr) + else: - raise ValueError(f"PerNeuronScaledSGD: Unsupported model structure with {model_type}") + raise ValueError(f"PerNeuronScaledSGD: Unsupported model type '{model_type}'") + return None diff --git a/group_agf/binary_action_learning/plot.py b/src/plot.py similarity index 99% rename from group_agf/binary_action_learning/plot.py rename to src/plot.py index 12b576a..96d9705 100644 --- a/group_agf/binary_action_learning/plot.py +++ b/src/plot.py @@ -4,7 +4,7 @@ import numpy as np import torch -import group_agf.binary_action_learning.power as power +import src.power as power FONT_SIZES = {"title": 30, "axes_label": 30, "tick_label": 30, "legend": 15} diff --git a/group_agf/binary_action_learning/power.py b/src/power.py similarity index 94% rename from group_agf/binary_action_learning/power.py rename to src/power.py index 601d546..9519e6b 100644 --- a/group_agf/binary_action_learning/power.py +++ b/src/power.py @@ -1,7 +1,7 @@ import numpy as np import torch -import group_agf.binary_action_learning.group_fourier_transform as gft +import src.group_fourier_transform as gft class CyclicPower: @@ -261,13 +261,21 @@ def model_power_over_time(group_name, model, param_history, model_inputs, group= reshape_dims = (-1, p1) num_points = 200 + max_step = len(param_history) - 1 
num_inputs_to_compute_power = max(1, len(model_inputs) // 50) # Ensure at least 1 input X_tensor = model_inputs[ :num_inputs_to_compute_power ] # Added by Nina to speed up computation with octahedral. - steps = np.unique(np.logspace(1, np.log10(len(param_history) - 1), num_points, dtype=int)) - steps = steps[steps > 50] - steps = np.hstack([np.linspace(1, 50, 5).astype(int), steps]) + if max_step <= 1: + # Very short training: just use all available checkpoints + steps = np.arange(max_step + 1) + else: + steps = np.unique(np.logspace(1, np.log10(max_step), num_points, dtype=int)) + steps = steps[steps > 50] + steps = np.hstack([np.linspace(1, min(50, max_step), 5).astype(int), steps]) + # Ensure all indices are within bounds + steps = np.unique(steps) + steps = steps[steps <= max_step] powers_over_time = np.zeros([len(steps), template_power_length]) for i_step, step in enumerate(steps): @@ -276,10 +284,10 @@ def model_power_over_time(group_name, model, param_history, model_inputs, group= model.eval() with torch.no_grad(): outputs = model(X_tensor) - print("outputs dtype", outputs.dtype) outputs_arr = outputs.detach().cpu().numpy().reshape(reshape_dims) - print("Computing power at step", step, "with output shape", outputs_arr.shape) + if i_step % 10 == 0: + print("Computing power at step", step, "with output shape", outputs_arr.shape) powers = [] for out in outputs_arr: diff --git a/gagf/rnns/run_sweep.py b/src/run_sweep.py similarity index 99% rename from gagf/rnns/run_sweep.py rename to src/run_sweep.py index 0594409..798371c 100644 --- a/gagf/rnns/run_sweep.py +++ b/src/run_sweep.py @@ -270,7 +270,7 @@ def run_single_seed( try: # Import here to avoid circular dependency - from gagf.rnns.main import train_single_run + from src.main import train_single_run # Run training result = train_single_run(seed_config, run_dir=seed_dir) diff --git a/gagf/rnns/sweep_configs/example_sweep.yaml b/src/sweep_configs/example_sweep.yaml similarity index 100% rename from gagf/rnns/sweep_configs/example_sweep.yaml rename to src/sweep_configs/example_sweep.yaml diff --git a/gagf/rnns/sweep_configs/learning_rate_sweep.yaml b/src/sweep_configs/learning_rate_sweep.yaml similarity index 100% rename from gagf/rnns/sweep_configs/learning_rate_sweep.yaml rename to src/sweep_configs/learning_rate_sweep.yaml diff --git a/gagf/rnns/sweep_configs/model_size_sweep.yaml b/src/sweep_configs/model_size_sweep.yaml similarity index 100% rename from gagf/rnns/sweep_configs/model_size_sweep.yaml rename to src/sweep_configs/model_size_sweep.yaml diff --git a/gagf/rnns/sweep_configs/onehot_scaling_sweep.yaml b/src/sweep_configs/onehot_scaling_sweep.yaml similarity index 100% rename from gagf/rnns/sweep_configs/onehot_scaling_sweep.yaml rename to src/sweep_configs/onehot_scaling_sweep.yaml diff --git a/group_agf/binary_action_learning/templates.py b/src/templates.py similarity index 99% rename from group_agf/binary_action_learning/templates.py rename to src/templates.py index 2906fc2..f2d8de2 100644 --- a/group_agf/binary_action_learning/templates.py +++ b/src/templates.py @@ -3,7 +3,7 @@ from sklearn.datasets import fetch_openml from sklearn.utils import shuffle -from group_agf.binary_action_learning.group_fourier_transform import ( +from src.group_fourier_transform import ( compute_group_inverse_fourier_transform, ) diff --git a/gagf/rnns/train.py b/src/train.py similarity index 100% rename from gagf/rnns/train.py rename to src/train.py diff --git a/gagf/rnns/utils.py b/src/utils.py similarity index 100% rename from 
gagf/rnns/utils.py rename to src/utils.py diff --git a/test/test_bal_datasets.py b/test/test_bal_datasets.py index 3850451..14b67f7 100644 --- a/test/test_bal_datasets.py +++ b/test/test_bal_datasets.py @@ -1,10 +1,10 @@ -"""Tests for group_agf.binary_action_learning.datasets module.""" +"""Tests for src.datasets module.""" import numpy as np import pytest import torch -from group_agf.binary_action_learning.datasets import ( +from src.datasets import ( cn_dataset, cnxcn_dataset, group_dataset, diff --git a/test/test_bal_group_fourier_transform.py b/test/test_bal_group_fourier_transform.py index 684733b..306fa89 100644 --- a/test/test_bal_group_fourier_transform.py +++ b/test/test_bal_group_fourier_transform.py @@ -1,11 +1,11 @@ import numpy as np from escnn.group import Octahedral -from group_agf.binary_action_learning.group_fourier_transform import ( +from src.group_fourier_transform import ( compute_group_fourier_transform, compute_group_inverse_fourier_transform, ) -from group_agf.binary_action_learning.templates import fixed_group_template +from src.templates import fixed_group_template def test_fourier_inverse_is_identity(): diff --git a/test/test_bal_models.py b/test/test_bal_models.py deleted file mode 100644 index 70e3e24..0000000 --- a/test/test_bal_models.py +++ /dev/null @@ -1,136 +0,0 @@ -"""Tests for group_agf.binary_action_learning.models module.""" - -import pytest -import torch - -from group_agf.binary_action_learning.models import TwoLayerNet - - -class TestTwoLayerNet: - """Tests for the TwoLayerNet model.""" - - @pytest.fixture - def default_params(self): - """Default parameters for TwoLayerNet.""" - return {"group_size": 6, "hidden_size": 20} - - def test_output_shape(self, default_params): - """Test that output shape is correct.""" - model = TwoLayerNet(**default_params) - batch_size = 8 - group_size = default_params["group_size"] - - # Input is flattened: (batch, 2 * group_size) - x = torch.randn(batch_size, 2 * group_size) - y = model(x) - - assert y.shape == ( - batch_size, - group_size, - ), f"Expected shape {(batch_size, group_size)}, got {y.shape}" - - def test_square_nonlinearity(self, default_params): - """Test that square nonlinearity produces finite results.""" - params = {**default_params, "nonlinearity": "square"} - model = TwoLayerNet(**params) - - x = torch.randn(4, 2 * default_params["group_size"]) - y = model(x) - - assert torch.isfinite(y).all(), "Output contains non-finite values" - - def test_relu_nonlinearity(self, default_params): - """Test that relu nonlinearity produces finite results.""" - params = {**default_params, "nonlinearity": "relu"} - model = TwoLayerNet(**params) - - x = torch.randn(4, 2 * default_params["group_size"]) - y = model(x) - - assert torch.isfinite(y).all(), "Output contains non-finite values" - - def test_tanh_nonlinearity(self, default_params): - """Test that tanh nonlinearity produces finite results.""" - params = {**default_params, "nonlinearity": "tanh"} - model = TwoLayerNet(**params) - - x = torch.randn(4, 2 * default_params["group_size"]) - y = model(x) - - assert torch.isfinite(y).all(), "Output contains non-finite values" - - def test_gelu_nonlinearity(self, default_params): - """Test that gelu nonlinearity produces finite results.""" - params = {**default_params, "nonlinearity": "gelu"} - model = TwoLayerNet(**params) - - x = torch.randn(4, 2 * default_params["group_size"]) - y = model(x) - - assert torch.isfinite(y).all(), "Output contains non-finite values" - - def test_linear_nonlinearity(self, default_params): 
- """Test that linear (no activation) produces finite results.""" - params = {**default_params, "nonlinearity": "linear"} - model = TwoLayerNet(**params) - - x = torch.randn(4, 2 * default_params["group_size"]) - y = model(x) - - assert torch.isfinite(y).all(), "Output contains non-finite values" - - def test_invalid_nonlinearity(self, default_params): - """Test that invalid nonlinearity raises an error.""" - params = {**default_params, "nonlinearity": "invalid"} - model = TwoLayerNet(**params) - - x = torch.randn(4, 2 * default_params["group_size"]) - - with pytest.raises(ValueError, match="Invalid nonlinearity"): - model(x) - - def test_gradient_flow(self, default_params): - """Test that gradients flow through the model.""" - model = TwoLayerNet(**default_params) - - x = torch.randn(4, 2 * default_params["group_size"], requires_grad=True) - y = model(x) - loss = y.sum() - loss.backward() - - # Check that gradients exist for all parameters - for name, param in model.named_parameters(): - assert param.grad is not None, f"No gradient for {name}" - assert torch.isfinite(param.grad).all(), f"Non-finite gradient for {name}" - - def test_default_hidden_size(self): - """Test that default hidden_size is computed correctly.""" - group_size = 8 - model = TwoLayerNet(group_size=group_size) - - # Default hidden_size should be 50 * group_size - assert model.hidden_size == 50 * group_size - - def test_output_scale(self, default_params): - """Test that output_scale affects the output magnitude.""" - scale_small = 0.1 - scale_large = 10.0 - - model_small = TwoLayerNet(**default_params, output_scale=scale_small) - model_large = TwoLayerNet(**default_params, output_scale=scale_large) - - # Same random seed for reproducibility - torch.manual_seed(42) - x = torch.randn(4, 2 * default_params["group_size"]) - - # Initialize both models with same weights - torch.manual_seed(42) - model_small = TwoLayerNet(**default_params, output_scale=scale_small) - torch.manual_seed(42) - model_large = TwoLayerNet(**default_params, output_scale=scale_large) - - y_small = model_small(x) - y_large = model_large(x) - - # Output with larger scale should have larger absolute values on average - assert y_large.abs().mean() > y_small.abs().mean() diff --git a/test/test_bal_power.py b/test/test_bal_power.py index 2fc5f75..72f4723 100644 --- a/test/test_bal_power.py +++ b/test/test_bal_power.py @@ -1,8 +1,8 @@ import numpy as np from escnn.group import Octahedral -from group_agf.binary_action_learning.power import GroupPower -from group_agf.binary_action_learning.templates import fixed_group_template +from src.power import GroupPower +from src.templates import fixed_group_template def test_power_custom_template(): diff --git a/test/test_bal_templates.py b/test/test_bal_templates.py index 2cd3509..e71209f 100644 --- a/test/test_bal_templates.py +++ b/test/test_bal_templates.py @@ -1,9 +1,9 @@ -"""Tests for group_agf.binary_action_learning.templates module.""" +"""Tests for src.templates module.""" import numpy as np import pytest -from group_agf.binary_action_learning.templates import ( +from src.templates import ( fixed_cn_template, fixed_cnxcn_template, fixed_group_template, diff --git a/test/test_config_a5.yaml b/test/test_config_a5.yaml new file mode 100644 index 0000000..3512eb1 --- /dev/null +++ b/test/test_config_a5.yaml @@ -0,0 +1,40 @@ +# Minimal test config: Icosahedral Group (A5) +data: + group_name: A5 + k: 2 + batch_size: 2 + seed: 42 + template_type: custom_fourier + mode: sampled + num_samples: 5 + dataset_fraction: 0.1 
+ powers: [0.0, 1800.0, 0.0, 1800.0, 0.0] + +model: + model_type: TwoLayerNet + hidden_dim: 10 + init_scale: 0.001 + nonlinearity: square + output_scale: 1.0 + return_all_outputs: false + transform_type: quadratic + +training: + mode: offline + epochs: 2 + num_steps: 2 + optimizer: adam + learning_rate: 0.001 + betas: [0.9, 0.999] + weight_decay: 0.0 + degree: null + scaling_factor: -3 + grad_clip: 0.1 + verbose_interval: 1 + save_param_interval: null + reduction_threshold: null + +device: cpu + +analysis: + checkpoints: [0.0, 1.0] diff --git a/test/test_config_c10.yaml b/test/test_config_c10.yaml new file mode 100644 index 0000000..526318d --- /dev/null +++ b/test/test_config_c10.yaml @@ -0,0 +1,37 @@ +# Minimal test config: Cyclic Group C_10 +data: + group_name: cn + p: 10 + k: 2 + batch_size: 2 + seed: 42 + template_type: onehot + mode: sampled + num_samples: 5 + +model: + model_type: SequentialMLP + hidden_dim: 10 + init_scale: 0.01 + return_all_outputs: false + transform_type: quadratic + +training: + mode: offline + epochs: 2 + num_steps: 2 + optimizer: adam + learning_rate: 0.001 + betas: [0.9, 0.999] + weight_decay: 0.0 + degree: null + scaling_factor: -3 + grad_clip: 0.1 + verbose_interval: 1 + save_param_interval: null + reduction_threshold: null + +device: cpu + +analysis: + checkpoints: [0.0, 1.0] diff --git a/test/test_config_c4x4.yaml b/test/test_config_c4x4.yaml new file mode 100644 index 0000000..4623165 --- /dev/null +++ b/test/test_config_c4x4.yaml @@ -0,0 +1,39 @@ +# Minimal test config: Product Group C_4 x C_4 +data: + group_name: cnxcn + p1: 4 + p2: 4 + k: 2 + batch_size: 2 + seed: 42 + template_type: fourier + n_freqs: 1 + mode: sampled + num_samples: 5 + +model: + model_type: SequentialMLP + hidden_dim: 10 + init_scale: 0.01 + return_all_outputs: false + transform_type: quadratic + +training: + mode: offline + epochs: 2 + num_steps: 2 + optimizer: adam + learning_rate: 0.001 + betas: [0.9, 0.999] + weight_decay: 0.0 + degree: null + scaling_factor: -3 + grad_clip: 0.1 + verbose_interval: 1 + save_param_interval: null + reduction_threshold: null + +device: cpu + +analysis: + checkpoints: [0.0, 1.0] diff --git a/test/test_config_d3.yaml b/test/test_config_d3.yaml new file mode 100644 index 0000000..3c3a836 --- /dev/null +++ b/test/test_config_d3.yaml @@ -0,0 +1,41 @@ +# Minimal test config: Dihedral Group D3 +data: + group_name: dihedral + group_n: 3 + k: 2 + batch_size: 2 + seed: 42 + template_type: custom_fourier + mode: sampled + num_samples: 5 + dataset_fraction: 0.5 + powers: [0.0, 5.0, 7.0] + +model: + model_type: TwoLayerNet + hidden_dim: 10 + init_scale: 0.000001 + nonlinearity: square + output_scale: 1.0 + return_all_outputs: false + transform_type: quadratic + +training: + mode: offline + epochs: 2 + num_steps: 2 + optimizer: adam + learning_rate: 0.001 + betas: [0.9, 0.999] + weight_decay: 0.0 + degree: null + scaling_factor: -3 + grad_clip: 0.1 + verbose_interval: 1 + save_param_interval: null + reduction_threshold: null + +device: cpu + +analysis: + checkpoints: [0.0, 1.0] diff --git a/test/test_config_octahedral.yaml b/test/test_config_octahedral.yaml new file mode 100644 index 0000000..a7d4aa3 --- /dev/null +++ b/test/test_config_octahedral.yaml @@ -0,0 +1,40 @@ +# Minimal test config: Octahedral Group +data: + group_name: octahedral + k: 2 + batch_size: 2 + seed: 42 + template_type: custom_fourier + mode: sampled + num_samples: 5 + dataset_fraction: 0.1 + powers: [0.0, 2000.0, 0.0, 0.0, 0.0] + +model: + model_type: TwoLayerNet + hidden_dim: 10 + 
init_scale: 0.001 + nonlinearity: square + output_scale: 1.0 + return_all_outputs: false + transform_type: quadratic + +training: + mode: offline + epochs: 2 + num_steps: 2 + optimizer: adam + learning_rate: 0.001 + betas: [0.9, 0.999] + weight_decay: 0.0 + degree: null + scaling_factor: -3 + grad_clip: 0.1 + verbose_interval: 1 + save_param_interval: null + reduction_threshold: null + +device: cpu + +analysis: + checkpoints: [0.0, 1.0] diff --git a/test/test_main.py b/test/test_main.py new file mode 100644 index 0000000..08adf66 --- /dev/null +++ b/test/test_main.py @@ -0,0 +1,178 @@ +""" +Tests for src/main.py + +This module tests that the main() entry point runs successfully with minimal +configuration for all supported groups: cn (C_10), cnxcn (C_4 x C_4), +dihedral (D3), octahedral, and A5. + +Tests are only run when MAIN_TEST_MODE=1 environment variable is set +to avoid long-running tests in regular CI. + +Expected runtime: < 1 minute with MAIN_TEST_MODE=1 + +Usage: + MAIN_TEST_MODE=1 pytest test/test_main.py -v +""" + +import os +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest + +# Check for MAIN_TEST_MODE +MAIN_TEST_MODE = os.environ.get("MAIN_TEST_MODE", "0") == "1" + +# Paths to test config files +TEST_DIR = Path(__file__).parent +CONFIG_FILES = { + "c10": TEST_DIR / "test_config_c10.yaml", + "c4x4": TEST_DIR / "test_config_c4x4.yaml", + "d3": TEST_DIR / "test_config_d3.yaml", + "octahedral": TEST_DIR / "test_config_octahedral.yaml", + "a5": TEST_DIR / "test_config_a5.yaml", +} + + +@pytest.fixture +def temp_run_dir(): + """Create a temporary directory for run outputs.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def mock_all_plots(): + """Mock all produce_plots_* and plt.savefig/close to skip visualization entirely.""" + import src.main # noqa: F401 + + with ( + patch("src.main.produce_plots_1d") as mock_1d, + patch("src.main.produce_plots_2d") as mock_2d, + patch("src.main.produce_plots_group") as mock_group, + patch("matplotlib.pyplot.savefig") as mock_savefig, + patch("matplotlib.pyplot.close") as mock_close, + ): + yield { + "produce_plots_1d": mock_1d, + "produce_plots_2d": mock_2d, + "produce_plots_group": mock_group, + "savefig": mock_savefig, + "close": mock_close, + } + + +@pytest.fixture +def mock_savefig(): + """Mock only plt.savefig and plt.close so plotting code runs but files aren't saved.""" + with ( + patch("matplotlib.pyplot.savefig") as mock_sf, + patch("matplotlib.pyplot.close") as mock_cl, + ): + yield {"savefig": mock_sf, "close": mock_cl} + + +@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1") +def test_load_config(): + """Test that load_config correctly loads a YAML file.""" + from src.main import load_config + + config = load_config(str(CONFIG_FILES["c10"])) + + assert "data" in config + assert "model" in config + assert "training" in config + assert "device" in config + assert "analysis" in config + assert config["data"]["group_name"] == "cn" + assert config["data"]["p"] == 10 + assert config["training"]["epochs"] == 2 + + +@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1") +def test_main_c10(temp_run_dir, mock_all_plots): + """Test main() with C_10 cyclic group config.""" + from src.main import load_config, train_single_run + + config = load_config(str(CONFIG_FILES["c10"])) + results = train_single_run(config, run_dir=temp_run_dir) + + assert "final_train_loss" in results + assert "final_val_loss" in 
results + assert results["final_train_loss"] > 0 + mock_all_plots["produce_plots_1d"].assert_called_once() + + +@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1") +def test_main_c4x4(temp_run_dir, mock_all_plots): + """Test main() with C_4 x C_4 product group config.""" + from src.main import load_config, train_single_run + + config = load_config(str(CONFIG_FILES["c4x4"])) + results = train_single_run(config, run_dir=temp_run_dir) + + assert "final_train_loss" in results + assert "final_val_loss" in results + assert results["final_train_loss"] > 0 + mock_all_plots["produce_plots_2d"].assert_called_once() + + +@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1") +def test_main_d3(temp_run_dir, mock_savefig): + """Test main() with D3 dihedral group config. + + Full integration test: does NOT mock produce_plots_group so the entire + plotting pipeline (TwoLayerNet eval data via group_dataset, power spectrum) + is exercised. D3 (order 6) is the smallest group so this stays fast. + This validates the TwoLayerNet-compatible eval data path in produce_plots_group, + which is shared by octahedral and A5 (mocked in their tests for speed). + """ + from src.main import load_config, train_single_run + + config = load_config(str(CONFIG_FILES["d3"])) + results = train_single_run(config, run_dir=temp_run_dir) + + assert "final_train_loss" in results + assert "final_val_loss" in results + assert results["final_train_loss"] > 0 + + +@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1") +def test_main_octahedral(temp_run_dir, mock_all_plots): + """Test main() with octahedral group config. + + Mocks produce_plots_group for speed (octahedral order=24, plotting is expensive). + Training + data pipeline still fully exercised. + """ + from src.main import load_config, train_single_run + + config = load_config(str(CONFIG_FILES["octahedral"])) + results = train_single_run(config, run_dir=temp_run_dir) + + assert "final_train_loss" in results + assert "final_val_loss" in results + assert results["final_train_loss"] > 0 + mock_all_plots["produce_plots_group"].assert_called_once() + + +@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1") +def test_main_a5(temp_run_dir, mock_all_plots): + """Test main() with A5 (icosahedral) group config. + + Mocks produce_plots_group for speed (A5 order=60, plotting is expensive). + Training + data pipeline still fully exercised. 
+ """ + from src.main import load_config, train_single_run + + config = load_config(str(CONFIG_FILES["a5"])) + results = train_single_run(config, run_dir=temp_run_dir) + + assert "final_train_loss" in results + assert "final_val_loss" in results + assert results["final_train_loss"] > 0 + mock_all_plots["produce_plots_group"].assert_called_once() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/test_notebooks.py b/test/test_notebooks.py index a4a520c..1e8ea2e 100644 --- a/test/test_notebooks.py +++ b/test/test_notebooks.py @@ -45,7 +45,7 @@ def get_notebooks_dir(): # These notebooks require pre-trained model files or external data "paper_figures": "Requires pre-trained model .pkl files not included in repo", # These notebooks have import/code issues that need separate debugging - "2D": "Missing function: cannot import 'get_power_2d' from gagf.rnns.utils", + "2D": "Missing function: cannot import 'get_power_2d' from src.utils", "znz_znz": "Missing function: datasets.choose_template() does not exist", "seq_mlp": "Plotting error: Invalid vmin/vmax values during visualization", # These notebooks have visualization code with hardcoded indices that fail with reduced p diff --git a/test/test_rnns_config.yaml b/test/test_rnns_config.yaml index 6c3ac4c..1f0250a 100644 --- a/test/test_rnns_config.yaml +++ b/test/test_rnns_config.yaml @@ -1,8 +1,8 @@ -# Minimal test configuration for gagf/rnns/main.py +# Minimal test configuration for src/main.py # Used by test_rnns_main.py for fast testing data: - dimension: 1 + group_name: cn p: 5 k: 2 batch_size: 32 diff --git a/test/test_rnns_datamodule.py b/test/test_rnns_datamodule.py index 5085c49..0db087b 100644 --- a/test/test_rnns_datamodule.py +++ b/test/test_rnns_datamodule.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from gagf.rnns.datamodule import ( +from src.datamodule import ( OnlineModularAdditionDataset1D, OnlineModularAdditionDataset2D, build_modular_addition_sequence_dataset_1d, diff --git a/test/test_rnns_main.py b/test/test_rnns_main.py deleted file mode 100644 index a0ec133..0000000 --- a/test/test_rnns_main.py +++ /dev/null @@ -1,210 +0,0 @@ -""" -Tests for gagf/rnns/main.py - -This module tests that the main() entry point runs successfully with minimal -configuration. Tests are only run when MAIN_TEST_MODE=1 environment variable -is set to avoid long-running tests in regular CI. 
- -Expected runtime: < 1 minute with MAIN_TEST_MODE=1 - -Usage: - MAIN_TEST_MODE=1 pytest test/test_rnns_main.py -v -""" - -import os -import tempfile -from pathlib import Path -from unittest.mock import patch - -import pytest - -# Check for MAIN_TEST_MODE -MAIN_TEST_MODE = os.environ.get("MAIN_TEST_MODE", "0") == "1" - - -@pytest.fixture -def temp_run_dir(): - """Create a temporary directory for run outputs.""" - with tempfile.TemporaryDirectory() as tmpdir: - yield Path(tmpdir) - - -@pytest.fixture -def test_config_path(): - """Return the path to the test config file.""" - return Path(__file__).parent / "test_rnns_config.yaml" - - -@pytest.fixture -def mock_plots(): - """Mock all plot functions to skip visualization.""" - with ( - patch("gagf.rnns.main.produce_plots_1d") as mock_1d, - patch("gagf.rnns.main.produce_plots_2d") as mock_2d, - patch("gagf.rnns.main.produce_plots_D3") as mock_d3, - patch("matplotlib.pyplot.savefig") as mock_savefig, - patch("matplotlib.pyplot.close") as mock_close, - ): - yield { - "produce_plots_1d": mock_1d, - "produce_plots_2d": mock_2d, - "produce_plots_D3": mock_d3, - "savefig": mock_savefig, - "close": mock_close, - } - - -@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1") -def test_main_with_config_file(temp_run_dir, test_config_path, mock_plots): - """ - Test main() by loading the test config file. - - This tests what happens when you run `python main.py --config test_rnns_config.yaml`. - """ - from gagf.rnns.main import load_config, main - - # Load the test config - config = load_config(str(test_config_path)) - - # Patch the setup_run_directory to use our temp directory - with patch("gagf.rnns.main.setup_run_directory") as mock_setup: - mock_setup.return_value = temp_run_dir - - # Run main - main(config) - - # Verify that plotting was skipped via mocking - # (we mock produce_plots_1d since we use dimension=1 in test config) - mock_plots["produce_plots_1d"].assert_called_once() - - -@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1") -def test_train_single_run_1d(temp_run_dir, mock_plots): - """ - Test train_single_run() directly with a minimal 1D config. - """ - from gagf.rnns.main import train_single_run - - # Create minimal config programmatically - config = { - "data": { - "dimension": 1, - "p": 5, - "k": 2, - "batch_size": 32, - "seed": 42, - "template_type": "onehot", - "mode": "sampled", - "num_samples": 100, - }, - "model": { - "model_type": "SequentialMLP", - "hidden_dim": 10, - "init_scale": 1e-2, - "return_all_outputs": False, - "transform_type": "quadratic", - }, - "training": { - "mode": "offline", - "epochs": 2, - "optimizer": "adam", - "learning_rate": 0.001, - "betas": [0.9, 0.999], - "weight_decay": 0.0, - "degree": None, - "scaling_factor": -3, - "grad_clip": 0.1, - "verbose_interval": 1, - "save_param_interval": None, - "reduction_threshold": None, - }, - "device": "cpu", - "analysis": { - "checkpoints": [0.0, 1.0], - }, - } - - # Run training - results = train_single_run(config, run_dir=temp_run_dir) - - # Verify results - assert "final_train_loss" in results - assert "final_val_loss" in results - assert "training_time" in results - assert results["final_train_loss"] > 0 - assert results["final_val_loss"] > 0 - - -@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1") -def test_train_single_run_with_quadratic_rnn(temp_run_dir, mock_plots): - """ - Test train_single_run() with QuadraticRNN model type. 
- """ - from gagf.rnns.main import train_single_run - - config = { - "data": { - "dimension": 1, - "p": 5, - "k": 2, - "batch_size": 32, - "seed": 42, - "template_type": "onehot", - "mode": "sampled", - "num_samples": 100, - }, - "model": { - "model_type": "QuadraticRNN", - "hidden_dim": 10, - "init_scale": 1e-2, - "return_all_outputs": False, - "transform_type": "quadratic", - }, - "training": { - "mode": "offline", - "epochs": 2, - "optimizer": "adam", - "learning_rate": 0.001, - "betas": [0.9, 0.999], - "weight_decay": 0.0, - "degree": None, - "scaling_factor": -3, - "grad_clip": 0.1, - "verbose_interval": 1, - "save_param_interval": None, - "reduction_threshold": None, - }, - "device": "cpu", - "analysis": { - "checkpoints": [0.0, 1.0], - }, - } - - results = train_single_run(config, run_dir=temp_run_dir) - - assert results["final_train_loss"] > 0 - assert results["final_val_loss"] > 0 - - -@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1") -def test_load_config(test_config_path): - """Test that load_config correctly loads the YAML file.""" - from gagf.rnns.main import load_config - - config = load_config(str(test_config_path)) - - # Verify expected keys exist - assert "data" in config - assert "model" in config - assert "training" in config - assert "device" in config - assert "analysis" in config - - # Verify some specific values from our test config - assert config["data"]["dimension"] == 1 - assert config["data"]["p"] == 5 - assert config["training"]["epochs"] == 2 - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/test/test_rnns_model.py b/test/test_rnns_model.py index 911a6c7..9a92a10 100644 --- a/test/test_rnns_model.py +++ b/test/test_rnns_model.py @@ -1,9 +1,9 @@ -"""Tests for gagf.rnns.model module.""" +"""Tests for src.model module (QuadraticRNN, SequentialMLP, TwoLayerNet).""" import pytest import torch -from gagf.rnns.model import QuadraticRNN, SequentialMLP +from src.model import QuadraticRNN, SequentialMLP, TwoLayerNet class TestQuadraticRNN: @@ -190,3 +190,128 @@ def test_k_power_activation(self, default_params): y = model(x) assert torch.isfinite(y).all(), "Output contains non-finite values" + + +class TestTwoLayerNet: + """Tests for the TwoLayerNet model.""" + + @pytest.fixture + def default_params(self): + """Default parameters for TwoLayerNet.""" + return {"group_size": 6, "hidden_size": 20} + + def test_output_shape(self, default_params): + """Test that output shape is correct.""" + model = TwoLayerNet(**default_params) + batch_size = 8 + group_size = default_params["group_size"] + + # Input is flattened: (batch, 2 * group_size) + x = torch.randn(batch_size, 2 * group_size) + y = model(x) + + assert y.shape == ( + batch_size, + group_size, + ), f"Expected shape {(batch_size, group_size)}, got {y.shape}" + + def test_square_nonlinearity(self, default_params): + """Test that square nonlinearity produces finite results.""" + params = {**default_params, "nonlinearity": "square"} + model = TwoLayerNet(**params) + + x = torch.randn(4, 2 * default_params["group_size"]) + y = model(x) + + assert torch.isfinite(y).all(), "Output contains non-finite values" + + def test_relu_nonlinearity(self, default_params): + """Test that relu nonlinearity produces finite results.""" + params = {**default_params, "nonlinearity": "relu"} + model = TwoLayerNet(**params) + + x = torch.randn(4, 2 * default_params["group_size"]) + y = model(x) + + assert torch.isfinite(y).all(), "Output contains non-finite values" + + def 
test_tanh_nonlinearity(self, default_params): + """Test that tanh nonlinearity produces finite results.""" + params = {**default_params, "nonlinearity": "tanh"} + model = TwoLayerNet(**params) + + x = torch.randn(4, 2 * default_params["group_size"]) + y = model(x) + + assert torch.isfinite(y).all(), "Output contains non-finite values" + + def test_gelu_nonlinearity(self, default_params): + """Test that gelu nonlinearity produces finite results.""" + params = {**default_params, "nonlinearity": "gelu"} + model = TwoLayerNet(**params) + + x = torch.randn(4, 2 * default_params["group_size"]) + y = model(x) + + assert torch.isfinite(y).all(), "Output contains non-finite values" + + def test_linear_nonlinearity(self, default_params): + """Test that linear (no activation) produces finite results.""" + params = {**default_params, "nonlinearity": "linear"} + model = TwoLayerNet(**params) + + x = torch.randn(4, 2 * default_params["group_size"]) + y = model(x) + + assert torch.isfinite(y).all(), "Output contains non-finite values" + + def test_invalid_nonlinearity(self, default_params): + """Test that invalid nonlinearity raises an error.""" + params = {**default_params, "nonlinearity": "invalid"} + model = TwoLayerNet(**params) + + x = torch.randn(4, 2 * default_params["group_size"]) + + with pytest.raises(ValueError, match="Invalid nonlinearity"): + model(x) + + def test_gradient_flow(self, default_params): + """Test that gradients flow through the model.""" + model = TwoLayerNet(**default_params) + + x = torch.randn(4, 2 * default_params["group_size"], requires_grad=True) + y = model(x) + loss = y.sum() + loss.backward() + + # Check that gradients exist for all parameters + for name, param in model.named_parameters(): + assert param.grad is not None, f"No gradient for {name}" + assert torch.isfinite(param.grad).all(), f"Non-finite gradient for {name}" + + def test_default_hidden_size(self): + """Test that default hidden_size is computed correctly.""" + group_size = 8 + model = TwoLayerNet(group_size=group_size) + + # Default hidden_size should be 50 * group_size + assert model.hidden_size == 50 * group_size + + def test_output_scale(self, default_params): + """Test that output_scale affects the output magnitude.""" + scale_small = 0.1 + scale_large = 10.0 + + # Same random seed for reproducibility + torch.manual_seed(42) + model_small = TwoLayerNet(**default_params, output_scale=scale_small) + torch.manual_seed(42) + model_large = TwoLayerNet(**default_params, output_scale=scale_large) + + x = torch.randn(4, 2 * default_params["group_size"]) + + y_small = model_small(x) + y_large = model_large(x) + + # Output with larger scale should have larger absolute values on average + assert y_large.abs().mean() > y_small.abs().mean() diff --git a/test/test_rnns_optimizers.py b/test/test_rnns_optimizers.py index 6f9cd3b..8eb6bc7 100644 --- a/test/test_rnns_optimizers.py +++ b/test/test_rnns_optimizers.py @@ -3,8 +3,8 @@ import pytest import torch -from gagf.rnns.model import QuadraticRNN, SequentialMLP -from gagf.rnns.optimizers import HybridRNNOptimizer, PerNeuronScaledSGD +from src.model import QuadraticRNN, SequentialMLP +from src.optimizers import HybridRNNOptimizer, PerNeuronScaledSGD class TestPerNeuronScaledSGD: diff --git a/test/test_rnns_utils.py b/test/test_rnns_utils.py index 0de5a26..24b3857 100644 --- a/test/test_rnns_utils.py +++ b/test/test_rnns_utils.py @@ -2,7 +2,7 @@ import numpy as np -from gagf.rnns.utils import ( +from src.utils import ( get_power_1d, get_power_2d_adele, 
topk_template_freqs,