From 1fc515af92273515ef42ade32854d3ab25f4d9b3 Mon Sep 17 00:00:00 2001 From: Nina Miolane Date: Fri, 6 Feb 2026 21:04:29 +0000 Subject: [PATCH 1/2] consolidate files, make function names consistents, remove outdated functions --- group_agf/__init__.py | 0 group_agf/binary_action_learning/__init__.py | 0 .../binary_action_learning/default_config.py | 67 - group_agf/binary_action_learning/main.py | 288 --- group_agf/binary_action_learning/train.py | 267 --- src/{datamodule.py => dataset.py} | 652 ++----- src/datasets.py | 213 -- src/fourier.py | 63 + src/group_fourier_transform.py | 138 -- src/main.py | 459 +---- src/{optimizers.py => optimizer.py} | 0 src/plot.py | 546 ------ src/power.py | 224 ++- src/run_sweep.py | 4 +- src/template.py | 667 +++++++ src/templates.py | 206 -- src/utils.py | 1737 ----------------- src/viz.py | 954 +++++++++ test/test_bal_datasets.py | 149 -- test/test_bal_group_fourier_transform.py | 49 - test/test_bal_main.py | 204 -- test/test_bal_power.py | 26 - test/test_bal_templates.py | 156 -- ...test_rnns_config.yaml => test_config.yaml} | 2 +- ...est_rnns_datamodule.py => test_dataset.py} | 179 +- test/test_default_config.py | 66 - test/test_fourier.py | 36 + test/test_main.py | 36 +- test/{test_rnns_model.py => test_model.py} | 131 +- test/test_notebooks.py | 2 +- ...t_rnns_optimizers.py => test_optimizer.py} | 77 +- test/{test_rnns_utils.py => test_power.py} | 134 +- test/test_template.py | 143 ++ 33 files changed, 2681 insertions(+), 5194 deletions(-) delete mode 100644 group_agf/__init__.py delete mode 100644 group_agf/binary_action_learning/__init__.py delete mode 100644 group_agf/binary_action_learning/default_config.py delete mode 100644 group_agf/binary_action_learning/main.py delete mode 100644 group_agf/binary_action_learning/train.py rename src/{datamodule.py => dataset.py} (54%) delete mode 100644 src/datasets.py create mode 100644 src/fourier.py delete mode 100644 src/group_fourier_transform.py rename src/{optimizers.py => 
optimizer.py} (100%) delete mode 100644 src/plot.py create mode 100644 src/template.py delete mode 100644 src/templates.py delete mode 100644 src/utils.py create mode 100644 src/viz.py delete mode 100644 test/test_bal_datasets.py delete mode 100644 test/test_bal_group_fourier_transform.py delete mode 100644 test/test_bal_main.py delete mode 100644 test/test_bal_power.py delete mode 100644 test/test_bal_templates.py rename test/{test_rnns_config.yaml => test_config.yaml} (93%) rename test/{test_rnns_datamodule.py => test_dataset.py} (58%) delete mode 100644 test/test_default_config.py create mode 100644 test/test_fourier.py rename test/{test_rnns_model.py => test_model.py} (73%) rename test/{test_rnns_optimizers.py => test_optimizer.py} (59%) rename test/{test_rnns_utils.py => test_power.py} (53%) create mode 100644 test/test_template.py diff --git a/group_agf/__init__.py b/group_agf/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/group_agf/binary_action_learning/__init__.py b/group_agf/binary_action_learning/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/group_agf/binary_action_learning/default_config.py b/group_agf/binary_action_learning/default_config.py deleted file mode 100644 index 25be107..0000000 --- a/group_agf/binary_action_learning/default_config.py +++ /dev/null @@ -1,67 +0,0 @@ -# Dataset Parameters -group_name = "cnxcn" # , "A5"] # , 'octahedral', 'cn', 'dihedral', 'cnxcn' 'A5'] -group_n = [6] # n in Dn [3, 4, 5] -template_type = "one_hot" # "one_hot", "irrep_construction"] - -powers = { - "cn": [[0, 12.5, 10, 7.5, 5, 2.5]], - "cnxcn": [[0, 12, 10, 8, 6, 4]], - "dihedral": [[0.0, 5.0, 0.0, 7.0, 0.0, 0.0]], # D6: [1,1,2,2,1,1], D5: [1,1,2,2], D3: [1,1,2] - "octahedral": [ # [1, 3, 3, 2, 1] - [0.0, 2000.0, 0.0, 0.0, 0.0], - ], - "A5": [ # [1, 3, 5, 3, 4] - [0.0, 1800.0, 0.0, 1800.0, 0.0], # 3:900, 3:900 - [0.0, 900.0, 0.0, 0.0, 1600.0], # 3:900, 4:1600 - [0.0, 0.0, 2500.0, 900.0, 0.0], # 5:2500, 3:900 - 
[0.0, 0.0, 2500.0, 0.0, 1600.0], # 5:2500, 4:1600 - ], -} - -# Model Parameters -hidden_factor = [30] # 20, 30, 40, 50] # hidden size = hidden_factor * group_size - -# Learning Parameters -seed = [10] -init_scale = { - "cn": [1e-2], - "cnxcn": [1e-2], - "dihedral": [1e-6], - "octahedral": [1e-3], - "A5": [1e-3], -} -lr = { - "cn": [0.01], - "cnxcn": [0.01], - "dihedral": [0.01], - "octahedral": [0.0001], - "A5": [0.0001], -} - -mom = [0.9] -optimizer_name = ["PerNeuronScaledSGD"] -epochs = [2] # [1000] # , 50000] -verbose_interval = 100 -checkpoint_interval = 200000 -batch_size = [128] # 128, 256] - -# plotting parameters -power_logscale = False - -# Change these if you want to resume training from a checkpoint -resume_from_checkpoint = True -checkpoint_epoch = 5000 - -# cnxcn specific parameters -image_length = [5] - -dataset_fraction = { - "cn": 1.0, - "cnxcn": 1.0, - "dihedral": 1.0, - "octahedral": 1.0, - "A5": 1.0, # [0.2, 0.3, 0.4, 0.5, 0.6] -} - -# model_save_dir = "/tmp/nmiolane/" -model_save_dir = "/tmp/nmiolane/" diff --git a/group_agf/binary_action_learning/main.py b/group_agf/binary_action_learning/main.py deleted file mode 100644 index f605f83..0000000 --- a/group_agf/binary_action_learning/main.py +++ /dev/null @@ -1,288 +0,0 @@ -import datetime -import itertools -import logging -import time - -import numpy as np -import torch -import torch.nn as nn -import torch.optim as optim -import wandb -from escnn.group import * -from torch.utils.data import DataLoader, TensorDataset - -import default_config -import group_agf.binary_action_learning.train as train -import src.datasets as datasets -import src.model as models -import src.plot as plot -import src.power as power -from src.optimizers import PerNeuronScaledSGD - -today = datetime.date.today() - - -def main_run(config): - """Run regression experiments.""" - full_run = True - print(f"run_start: {today}") - wandb.init( - project="gagf", - tags=[ - f"{today}", - f"run_start_{config['run_start_time']}", - 
], - ) - wandb_config = wandb.config - wandb_config.update(config) - - run_name = f"run_{wandb.run.id}" - wandb.run.name = run_name - try: - logging.info(f"\n\n---> START run: {run_name}.") - - print("Generating dataset...") - X, Y, template = datasets.load_dataset(config) - assert len(template) == config["group_size"], "Template size does not match group size." - - if config["group_name"] == "cnxcn": - template_power = power.CyclicPower(template, template_dim=2) - elif config["group_name"] == "cn": - template_power = power.CyclicPower(template, template_dim=1) - else: - template_power = power.GroupPower(template, group=config["group"]) - - print(f"Template powers:\n {template_power.power}") - - X, Y, device = datasets.move_dataset_to_device_and_flatten(X, Y, device=None) - - # Determine batch size: if 'full', set to all samples - if config["batch_size"] == "full": - config["batch_size"] = X.shape[0] - - if default_config.resume_from_checkpoint: - config["checkpoint_path"] = train.get_model_save_path( - config, - checkpoint_epoch=default_config.checkpoint_epoch, - ) - config["run_name"] = run_name - dataset = TensorDataset(X, Y) - dataloader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=False) - - np.random.seed(config["seed"]) - torch.manual_seed(config["seed"]) - torch.cuda.manual_seed_all(config["seed"]) # if using GPU - - model = models.TwoLayerNet( - group_size=config["group_size"], - hidden_size=config["hidden_factor"] * config["group_size"], - nonlinearity="square", - init_scale=config["init_scale"], - output_scale=1e0, - ) - model = model.to(device) - loss = nn.MSELoss() - - if config["optimizer_name"] == "Adam": - optimizer = optim.Adam( - model.parameters(), lr=config["lr"], betas=(config["mom"], 0.999) - ) - elif config["optimizer_name"] == "SGD": - optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=config["mom"]) - elif config["optimizer_name"] == "PerNeuronScaledSGD": - optimizer = PerNeuronScaledSGD(model, 
lr=config["lr"]) - else: - raise ValueError( - f"Unknown optimizer: {config['optimizer_name']}. Expected one of ['Adam', 'SGD', 'PerNeuronScaledSGD']." - ) - - print("Starting training...") - loss_history, _, param_history = train.train( - config, - model, - dataloader, - loss, - optimizer, - ) - - print("Training Complete. Generating plots...") - - loss_plot = plot.plot_loss_curve( - loss_history, - template_power, - save_path=config["model_save_dir"] + f"loss_plot_{run_name}.svg", - show=False, - ) - - power_over_training_plot = plot.plot_training_power_over_time( - template_power, - model, - device, - param_history, - X, - config["group_name"], - save_path=config["model_save_dir"] + f"power_over_training_plot_{run_name}.svg", - show=False, - logscale=config["power_logscale"], - ) - print( - f"loss plot and power over training plot saved to {config['model_save_dir']}" - f" at loss_plot_{run_name}.svg and power_over_training_plot_{run_name}.svg" - ) - neuron_weights_plot = plot.plot_neuron_weights( - config, - model, - neuron_indices=None, - ) - - model_predictions_plot = plot.plot_model_outputs( - config["group_name"], config["group_size"], model, X, Y, idx=13 - ) - - wandb.log( - { - "loss_plot": wandb.Image(loss_plot), - # "irreps_plot": wandb.Image(irreps_plot), - "power_over_training_plot": wandb.Image(power_over_training_plot), - "neuron_weights_plot": wandb.Image(neuron_weights_plot), - "model_predictions_plot": wandb.Image(model_predictions_plot), - } - ) - - print("Plots generated and logged to wandb.") - if config["group_name"] not in ("cnxcn", "cn"): - print(f"With irreps' sizes:\n {[irrep.size for irrep in config['group'].irreps()]}") - - wandb_config.update({"full_run": full_run}) - wandb.finish() - - except Exception as e: - full_run = False - wandb_config.update({"full_run": full_run}) - logging.exception(e) - wandb.finish() - - -def main(): - """Parse the default_config file and launch all experiments. 
- - This launches experiments with wandb with different config parameters. - """ - run_start_time = time.strftime("%m-%d_%H-%M-%S") - for ( - init_scale, - hidden_factor, - seed, - lr, - mom, - optimizer_name, - batch_size, - epochs, - powers, - ) in itertools.product( - default_config.init_scale[default_config.group_name], - default_config.hidden_factor, - default_config.seed, - default_config.lr[default_config.group_name], - default_config.mom, - default_config.optimizer_name, - default_config.batch_size, - default_config.epochs, - default_config.powers[default_config.group_name], - ): - group_name = default_config.group_name - - main_config = { - "group_name": group_name, - "init_scale": init_scale, - "run_start_time": run_start_time, - "hidden_factor": hidden_factor, - "seed": seed, - "lr": lr, - "mom": mom, - "optimizer_name": optimizer_name, - "batch_size": batch_size, - "epochs": epochs, - "verbose_interval": default_config.verbose_interval, - "model_save_dir": default_config.model_save_dir, - "powers": powers, - "dataset_fraction": default_config.dataset_fraction[group_name], - "power_logscale": default_config.power_logscale, - "resume_from_checkpoint": default_config.resume_from_checkpoint, - "checkpoint_interval": default_config.checkpoint_interval, - "checkpoint_path": None, - "template_type": default_config.template_type, - } - - if group_name == "cnxcn": - for (image_length,) in itertools.product( - default_config.image_length, - ): - group_size = image_length * image_length - main_config["group_size"] = group_size - main_config["image_length"] = image_length - main_config["dataset_fraction"] = default_config.dataset_fraction["cnxcn"] - main_config["fourier_coef_diag_values"] = main_config["powers"] - main_run(main_config) - - elif group_name == "cn": - for (group_n,) in itertools.product(default_config.group_n): - main_config["group_size"] = group_n - main_config["group_n"] = group_n - main_config["dataset_fraction"] = 
default_config.dataset_fraction["cn"] - main_config["fourier_coef_diag_values"] = main_config["powers"] - main_run(main_config) - - elif group_name == "octahedral": - group = Octahedral() - group_size = group.order() - irreps = group.irreps() - irrep_dims = [ir.size for ir in irreps] - print(f"Running for group: {group_name}{group_n} with irrep dims {irrep_dims}") - main_config["group"] = group - main_config["group_size"] = group_size - main_config["fourier_coef_diag_values"] = [ - np.sqrt(group_size * p / dim**2) - for p, dim in zip(main_config["powers"], irrep_dims) - ] - main_run(main_config) - - elif group_name == "A5": - group = Icosahedral() - group_size = group.order() - irreps = group.irreps() - irrep_dims = [ir.size for ir in irreps] - print(f"Running for group: {group_name}{group_n} with irrep dims {irrep_dims}") - main_config["group"] = group - main_config["group_size"] = group_size - main_config["fourier_coef_diag_values"] = [ - np.sqrt(group_size * p / dim**2) - for p, dim in zip(main_config["powers"], irrep_dims) - ] - main_run(main_config) - - else: - for (group_n,) in itertools.product(default_config.group_n): - if group_name == "dihedral": - group = DihedralGroup(group_n) - else: - raise ValueError( - f"Unknown group_name: {group_name}. " - f"Expected one of ['dihedral', 'cn', 'octahedral']." 
- ) - group_size = group.order() - irreps = group.irreps() - irrep_dims = [ir.size for ir in irreps] - print(f"Running for group: {group_name}{group_n} with irrep dims {irrep_dims}") - main_config["group"] = group - main_config["group_size"] = group_size - main_config["group_n"] = group_n - main_config["dataset_fraction"] = default_config.dataset_fraction[group_name] - main_config["fourier_coef_diag_values"] = [ - np.sqrt(group_size * p / dim**2) - for p, dim in zip(main_config["powers"], irrep_dims) - ] - main_run(main_config) - - -main() diff --git a/group_agf/binary_action_learning/train.py b/group_agf/binary_action_learning/train.py deleted file mode 100644 index bb5ec0c..0000000 --- a/group_agf/binary_action_learning/train.py +++ /dev/null @@ -1,267 +0,0 @@ -import os -import pickle - -import torch - - -def test_accuracy(model, dataloader): - correct = 0 - total = 0 - - with torch.no_grad(): # Disable gradient calculation for evaluation - for batch_idx, (inputs, labels) in enumerate(dataloader): - inputs = inputs.view(inputs.shape[0], -1) # Flatten input for FC layers - outputs = model(inputs) - _, predicted = torch.max(outputs, 1) # Get the index of the largest value (class) - _, true_labels = torch.max(labels, 1) # Get the true class from the one-hot encoding - correct += (predicted == true_labels).sum().item() - total += labels.size(0) - - accuracy = 100 * correct / total - return accuracy - - -def get_model_save_path(config, checkpoint_epoch): - """Generate a unique model save path based on the config parameters.""" - model_save_path = ( - f"{config['model_save_dir']}model_" - f"group_name{config['group_name']}_" - f"group_size{config['group_size']}_" - f"template_type{config['template_type']}_" - f"frac{config['dataset_fraction']}_" - f"init{config['init_scale']}_" - f"lr{config['lr']}_" - f"mom{config['mom']}_" - f"bs{config['batch_size']}_" - f"checkpoint_epoch{checkpoint_epoch}_" - f"seed{config['seed']}.pt" - ) - - return model_save_path - - -def 
save_param_history(param_history_path, param_history): - """Save param_history separately for analysis (can be very large).""" - torch.save({"param_history": param_history}, param_history_path) - print( - f"Parameter history saved to {param_history_path} " - f"(size: {os.path.getsize(param_history_path) / (1024**3):.2f} GB)" - ) - - -def load_param_history(param_history_path): - """Load param_history separately for analysis (can be very large).""" - param_history = torch.load(param_history_path)["param_history"] - return param_history - - -def save_checkpoint( - checkpoint_path, - model, - optimizer, - epoch, - loss_history, - accuracy_history, - param_history=None, - save_param_history=True, -): - # Get optimizer state dict and remove model reference from param_groups if present - # (model reference can't be pickled and will be restored from the model parameter during load) - optimizer_state = optimizer.state_dict() - if "param_groups" in optimizer_state: - # Create a copy without the model reference - param_groups_clean = [] - for group in optimizer_state["param_groups"]: - group_clean = {k: v for k, v in group.items() if k != "model"} - param_groups_clean.append(group_clean) - optimizer_state_clean = { - "state": optimizer_state["state"], - "param_groups": param_groups_clean, - } - else: - optimizer_state_clean = optimizer_state - - # Build checkpoint dict - only include param_history if requested (saves disk space) - checkpoint_dict = { - "model_state_dict": model.state_dict(), - "optimizer_state_dict": optimizer_state_clean, - "epoch": epoch, - "loss_history": loss_history, - "accuracy_history": accuracy_history, - } - if save_param_history and param_history is not None: - checkpoint_dict["param_history"] = param_history - - try: - torch.save(checkpoint_dict, checkpoint_path) - print( - f"Training history saved to {checkpoint_path}. You can reload it later with torch.load({checkpoint_path}, map_location='cpu')." 
- ) - except RuntimeError as e: - if "No space left on device" in str(e) or "unexpected pos" in str(e): - print(f"ERROR: Failed to save checkpoint due to disk space issues: {e}") - print(f"Checkpoint path: {checkpoint_path}") - print("Consider cleaning up old checkpoints or using a different save location.") - raise - else: - raise - - -def load_checkpoint(checkpoint_path, model, optimizer=None, map_location="cpu"): - # Try loading with torch.load first (for .pt files or new .pkl files saved with torch.save) - # If that fails, try pickle.load for backward compatibility with old .pkl files - try: - checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False) - except Exception as e: - # Fallback to pickle for old checkpoints - print(f"Warning: torch.load failed, trying pickle.load for backward compatibility: {e}") - - with open(checkpoint_path, "rb") as f: - checkpoint = pickle.load(f) - - model.load_state_dict(checkpoint["model_state_dict"]) - if optimizer is not None and "optimizer_state_dict" in checkpoint: - # For PerNeuronScaledSGD, we need to restore the model reference in param_groups - optimizer_state = checkpoint["optimizer_state_dict"] - # Restore model reference in param_groups if it was removed during save - if "param_groups" in optimizer_state and len(optimizer_state["param_groups"]) > 0: - # Check if optimizer expects a model reference (e.g., PerNeuronScaledSGD) - if hasattr(optimizer, "param_groups") and len(optimizer.param_groups) > 0: - if "model" in optimizer.param_groups[0]: - # Restore model reference before loading state dict - for group in optimizer_state["param_groups"]: - group["model"] = model - try: - optimizer.load_state_dict(optimizer_state) - except Exception as e: - print(f"Warning: Could not fully load optimizer state: {e}") - print("Optimizer will continue with current configuration.") - print(f"Loaded checkpoint from {checkpoint_path} (epoch {checkpoint.get('epoch', -1)})") - return checkpoint - - -def 
train( - config, - model, - dataloader, - criterion, - optimizer, -): - """Train the model with checkpointing and resume capability. - - Parameters: - ---------- - config : dict - Configuration dictionary with training parameters. - model : torch.nn.Module - The neural network model to train. - dataloader : torch.utils.data.DataLoader - DataLoader providing the training data. - criterion : torch.nn.Module - Loss function. - optimizer : torch.optim.Optimizer - Optimizer for training. - Returns: - ------- - loss_history : list - List of loss values for each epoch. - accuracy_history : list - List of accuracy values for each epoch. - param_history : list - List of model parameters for each epoch. - """ - - model.train() - start_epoch = 0 - loss_history = [] - accuracy_history = [] - param_history = [] - - if ( - config["resume_from_checkpoint"] - and config["checkpoint_path"] is not None - and os.path.isfile(config["checkpoint_path"]) - ): - print(f"Resuming from checkpoint at {config['checkpoint_path']}.") - checkpoint = load_checkpoint(config["checkpoint_path"], model, optimizer) - param_history = load_param_history( - config["checkpoint_path"].replace(".pt", "_param_history.pt") - ) - start_epoch = checkpoint.get("epoch", 0) + 1 - loss_history = checkpoint.get("loss_history", []) - accuracy_history = checkpoint.get("accuracy_history", []) - print(f"Resuming training from epoch {start_epoch}") - else: - print( - f"Starting training from scratch (no checkpoint to resume). 
Checkpoint path was: {config['checkpoint_path']}" - ) - - for epoch in range(start_epoch, config["epochs"]): - running_loss = 0.0 - for inputs, labels in dataloader: - inputs = inputs.view(inputs.shape[0], -1) # Flatten input for FC layers - - optimizer.zero_grad() - outputs = model(inputs) - loss = criterion(outputs, labels) - loss.backward() - optimizer.step() - running_loss += loss.item() - - # Append the average loss for the epoch to loss_history - avg_loss = running_loss / len(dataloader) - # Detect NaN loss early in training and raise an error - if torch.isnan(torch.tensor(avg_loss)): - if epoch < 0.75 * config["epochs"]: - raise RuntimeError( - f"NaN loss encountered at epoch {epoch + 1} (avg_loss={avg_loss})." - ) - loss_history.append(avg_loss) - - # Append accuracy - model.eval() - accuracy = test_accuracy(model, dataloader) - accuracy_history.append(accuracy) - model.train() - - # Save current model parameters - current_params = { - "U": model.U.detach().cpu().clone(), - "V": model.V.detach().cpu().clone(), - "W": model.W.detach().cpu().clone(), - } - param_history.append(current_params) - - # Print verbose information every `verbose_interval` epochs - if (epoch + 1) % config["verbose_interval"] == 0: - print( - f"Epoch {epoch + 1}/{config['epochs']}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%" - ) - - # Save checkpoint if at checkpoint interval or at the end of the training - if (epoch + 1) % config["checkpoint_interval"] == 0 or (epoch + 1) == config["epochs"]: - checkpoint_path = get_model_save_path(config, checkpoint_epoch=(epoch + 1)) - # Only save param_history in the final checkpoint to save disk space - # (param_history can be very large as it stores all parameters for every epoch) - is_final_checkpoint = (epoch + 1) == config["epochs"] - save_checkpoint( - checkpoint_path, - model, - optimizer, - epoch, - loss_history, - accuracy_history, - param_history, - save_param_history=is_final_checkpoint, - ) - - # Save param_history separately only 
at the end of training (it can be very large) - if (epoch + 1) == config["epochs"]: - param_history_path = checkpoint_path.replace(".pt", "_param_history.pt") - save_param_history(param_history_path, param_history) - - return ( - loss_history, - accuracy_history, - param_history, - ) # Return loss history for plotting diff --git a/src/datamodule.py b/src/dataset.py similarity index 54% rename from src/datamodule.py rename to src/dataset.py index d15384f..7bc81e5 100644 --- a/src/datamodule.py +++ b/src/dataset.py @@ -1,8 +1,5 @@ import numpy as np import torch -import torch.nn as nn -import torchvision -import torchvision.transforms as transforms from torch.utils.data import IterableDataset @@ -537,523 +534,170 @@ def sequence_to_paths_xy(sequence_xy: np.ndarray, p1: int, p2: int) -> np.ndarra return paths_xy -def mnist_template_1d(p: int, label: int, root: str = "data", axis: int = 0): - """ - Return a (p,) 1D template from a random MNIST image by taking a slice or projection. - Values are float32 in [0, 1]. 
- - Args: - p: dimension of the cyclic group - label: MNIST digit class (0-9) - root: MNIST data directory - axis: 0 for row average, 1 for column average, 2 for diagonal - - Returns: - template: (p,) array - """ - if not (0 <= int(label) <= 9): - raise ValueError("label must be an integer in [0, 9].") - - ds = torchvision.datasets.MNIST( - root=root, train=True, download=True, transform=transforms.ToTensor() - ) - cls_idxs = (ds.targets == int(label)).nonzero(as_tuple=True)[0] - if cls_idxs.numel() == 0: - raise ValueError(f"No samples for label {label}.") - - idx = cls_idxs[torch.randint(len(cls_idxs), (1,)).item()].item() - img, _ = ds[idx] # img: (1, 28, 28) in [0,1] - img = img[0].numpy() # (28, 28) - - # Get 1D signal from 2D image - if axis == 0: - # Average over columns (vertical projection) - signal = img.mean(axis=1) # (28,) - elif axis == 1: - # Average over rows (horizontal projection) - signal = img.mean(axis=0) # (28,) - elif axis == 2: - # Diagonal - signal = np.diag(img) # (28,) - else: - raise ValueError("axis must be 0, 1, or 2") +# --------------------------------------------------------------------------- +# Dataset functions moved from datasets.py +# --------------------------------------------------------------------------- - # Interpolate to desired size p - from scipy.interpolate import interp1d - x_old = np.linspace(0, 1, len(signal)) - x_new = np.linspace(0, 1, p) - f = interp1d(x_old, signal, kind="cubic") - template = f(x_new) +def load_dataset(config): + """Load dataset based on configuration.""" + import src.template as template - return template.astype(np.float32) + tpl = template.template_selector(config) + if config["group_name"] == "cnxcn": + X, Y = cnxcn_dataset(tpl) -def mnist_template_2d(p1: int, p2: int, label: int, root: str = "data"): - """ - Return a (p1, p2) template from a random MNIST image of the given class label (0–9). - Values are float32 in [0, 1]. 
- """ - if not (0 <= int(label) <= 9): - raise ValueError("label must be an integer in [0, 9].") - - ds = torchvision.datasets.MNIST( - root=root, train=True, download=True, transform=transforms.ToTensor() - ) - cls_idxs = (ds.targets == int(label)).nonzero(as_tuple=True)[0] - if cls_idxs.numel() == 0: - raise ValueError(f"No samples for label {label}.") - - idx = cls_idxs[torch.randint(len(cls_idxs), (1,)).item()].item() - img, _ = ds[idx] # img: (1, 28, 28) in [0,1] - img = nn.functional.interpolate( - img.unsqueeze(0), size=(p1, p2), mode="bilinear", align_corners=False - )[0, 0] - return img.numpy().astype(np.float32) # (p1, p2) - - -### ----- SYNTHETIC TEMPLATES ----- ### - -### 1D Templates ### - - -def generate_fourier_template_1d( - p: int, n_freqs: int, amp_max: float = 100, amp_min: float = 10, seed=None -): - """ - Generate 1D template from random Fourier modes. - - Args: - p: dimension of cyclic group - n_freqs: number of frequency components to include - amp_max: maximum amplitude - amp_min: minimum amplitude - seed: random seed - - Returns: - template: (p,) real-valued array - """ - rng = np.random.default_rng(seed) - spectrum = np.zeros(p, dtype=np.complex128) - - # Select frequencies (skip DC) - available_freqs = list(range(1, p // 2 + 1)) - if len(available_freqs) < n_freqs: - raise ValueError( - f"Only {len(available_freqs)} non-DC frequencies available for p={p}, requested {n_freqs}" - ) - - chosen_freqs = rng.choice( - available_freqs, size=min(n_freqs, len(available_freqs)), replace=False - ) + elif config["group_name"] == "cn": + X, Y = cn_dataset(tpl) - # Amplitudes decreasing with frequency index - amps = np.sqrt(np.linspace(amp_max, amp_min, len(chosen_freqs))) - phases = rng.uniform(0.0, 2 * np.pi, size=len(chosen_freqs)) - - for freq, amp, phi in zip(chosen_freqs, amps, phases): - v = amp * np.exp(1j * phi) - spectrum[freq] = v - spectrum[-freq] = np.conj(v) # Hermitian symmetry for real signal - - template = np.fft.ifft(spectrum).real - 
template -= template.mean() - s = template.std() - if s > 1e-12: - template /= s - - return template.astype(np.float32) - - -def generate_gaussian_template_1d( - p: int, n_gaussians: int = 3, sigma_range: tuple = (0.5, 2.0), seed=None -): - """ - Generate 1D template as sum of Gaussians. - - Args: - p: dimension of cyclic group - n_gaussians: number of Gaussian bumps - sigma_range: (min_sigma, max_sigma) for Gaussian widths - seed: random seed - - Returns: - template: (p,) real-valued array - """ - rng = np.random.default_rng(seed) - x = np.arange(p) - template = np.zeros(p, dtype=np.float32) - - for _ in range(n_gaussians): - center = rng.uniform(0, p) - sigma = rng.uniform(*sigma_range) - amplitude = rng.uniform(0.5, 1.0) + else: + X, Y = group_dataset(config["group"], tpl) - # Periodic distance - dist = np.minimum(np.abs(x - center), p - np.abs(x - center)) - template += amplitude * np.exp(-(dist**2) / (2 * sigma**2)) + print(f"dataset_fraction: {config['dataset_fraction']}") - template -= template.mean() - s = template.std() - if s > 1e-12: - template /= s + if config["dataset_fraction"] != 1.0: + assert 0 < config["dataset_fraction"] <= 1.0, "fraction must be in (0, 1]" + N = X.shape[0] + n_sample = int(np.ceil(N * config["dataset_fraction"])) + rng = np.random.default_rng(config["seed"]) + indices = rng.choice(N, size=n_sample, replace=False) + X = X[indices] + Y = Y[indices] - return template.astype(np.float32) + return X, Y, tpl -def generate_onehot_template_1d(p: int): - """ - Generate 1D one-hot template for cyclic group C_p. +def group_dataset(group, template): + """Generate a dataset of group elements acting on the template. - This creates a template with a single 1 at position 0 and 0s everywhere else. - When rolled, this one-hot encoding uniquely identifies each group element. + Using the regular representation. - Args: - p: dimension of cyclic group + Parameters + ---------- + group : Group (escnn object) + The group. 
+ template : np.ndarray, shape=[group.order()] + The template to generate the dataset from. - Returns: - template: (p,) array with template[0] = 1, all others = 0 - """ - template = np.zeros(p, dtype=np.float32) - template[0] = 1.0 - return template - - -### 2D Templates ### - - -def gaussian_mixture_template( - p1=20, - p2=20, - n_blobs=8, - frac_broad=0.7, - sigma_broad=(3.5, 6.0), - sigma_narrow=(1.0, 2.0), - amp_broad=1.0, - amp_narrow=0.5, - seed=None, - normalize=True, -): - """ - Build a (p1 x p2) template as a periodic mixture of Gaussians. - Broad Gaussians (low-frequency) get higher weight; a few narrow ones add detail. + Returns + ------- + X : np.ndarray, shape=[group.order()**2, 2, group.order()] + Y : np.ndarray, shape=[group.order()**2, group.order()] """ - rng = np.random.default_rng(seed) - H, W = p1, p2 - Y, X = np.meshgrid(np.arange(H), np.arange(W), indexing="ij") - - k_broad = int(round(n_blobs * frac_broad)) - k_narrow = n_blobs - k_broad - - def add_blobs(k, sigma_range, amp): - out = np.zeros((H, W), dtype=float) - for _ in range(k): - cy, cx = rng.uniform(0, H), rng.uniform(0, W) - sigma = rng.uniform(*sigma_range) - dy = np.minimum(np.abs(Y - cy), H - np.abs(Y - cy)) # periodic (torus) distance - dx = np.minimum(np.abs(X - cx), W - np.abs(X - cx)) - out += amp * np.exp(-(dx**2 + dy**2) / (2.0 * sigma**2)) - return out - - template = ( - add_blobs(k_broad, sigma_broad, amp_broad) # broad, low-freq power - + add_blobs(k_narrow, sigma_narrow, amp_narrow) # a bit of high-freq detail - ) - - if normalize: - template -= template.mean() - s = template.std() - if s > 1e-12: - template /= s - return template.astype(np.float32) - + group_order = group.order() + assert len(template) == group_order, "template must have the same length as the group order" + n_samples = group_order**2 + X = np.zeros((n_samples, 2, group_order)) + Y = np.zeros((n_samples, group_order)) + regular_rep = group.representations["regular"] -def 
generate_template_unique_freqs(p1, p2, n_freqs, amp_max=100, amp_min=10, seed=None): - """ - Real (p1 x p2) template from n_freqs Fourier modes where: - - No two selected bins are conjugates of each other. - - Self-conjugate singletons are excluded. - - Frequencies are chosen (low→high) by radial order from the rfft-style half-plane. - - Conjugate symmetry: F[ky,kx] = conj( F[-ky mod p1, -kx mod p2] ). - On the rfft half-plane kx ∈ [0, p2//2]: - - If 0 < kx < p2//2, the conjugate sits at kx' = p2 - kx (outside the half-plane) → safe. - - If kx in {0, p2//2 (when even)}, the conjugate keeps the same kx and flips ky → avoid picking both ky and -ky. - - Self-conjugate happens only if kx in {0, p2//2 (when even)} AND ky in {0, p1//2 (when even)} → exclude. - """ - rng = np.random.default_rng(seed) - spectrum = np.zeros((p1, p2), dtype=np.complex128) - - # Helpers - def ky_signed(ky): # map ky ∈ [0..p1-1] to signed range - return ky if ky <= p1 // 2 else ky - p1 - - def is_self_conj(ky, kx): - on_self_kx = (kx == 0) or (p2 % 2 == 0 and kx == p2 // 2) - if not on_self_kx: - return False - s = ky_signed(ky) - return (s == 0) or (p1 % 2 == 0 and abs(s) == p1 // 2) - - # Build candidate list on rfft half-plane, skip DC and self-conjugate singletons - cand = [] - for ky in range(p1): - s = ky_signed(ky) - for kx in range(p2 // 2 + 1): - if ky == 0 and kx == 0: - continue # DC - if is_self_conj(ky, kx): - continue # exclude singletons - r2 = (s**2) + (kx**2) - cand.append((r2, ky, kx)) - cand.sort(key=lambda t: (t[0], abs(ky_signed(t[1])), t[2])) - - # Select without conjugate collisions - chosen = [] - seen_axis_pairs = set() # for kx in {0, mid}, prevent picking both ky and -ky - - mid_kx = p2 // 2 if (p2 % 2 == 0) else None - for _, ky, kx in cand: - if len(chosen) >= n_freqs: - break - if (kx == 0) or (mid_kx is not None and kx == mid_kx): - s = ky_signed(ky) - key = (kx, min(s, -s)) # canonicalize ±ky - if key in seen_axis_pairs: - continue - seen_axis_pairs.add(key) - 
chosen.append((ky, kx)) + idx = 0 + for g1 in group.elements: + for g2 in group.elements: + g1_rep = regular_rep(g1) + g2_rep = regular_rep(g2) + g12_rep = g1_rep @ g2_rep + + X[idx, 0, :] = g1_rep @ template + X[idx, 1, :] = g2_rep @ template + Y[idx, :] = g12_rep @ template + idx += 1 + + return X, Y + + +def cn_dataset(template): + """Generate a dataset for the cyclic group C_n modular addition operation.""" + group_size = len(template) + X = np.zeros((group_size * group_size, 2, group_size)) + Y = np.zeros((group_size * group_size, group_size)) + + idx = 0 + for a in range(group_size): + for b in range(group_size): + q = (a + b) % group_size + X[idx, 0, :] = np.roll(template, a) + X[idx, 1, :] = np.roll(template, b) + Y[idx, :] = np.roll(template, q) + idx += 1 + + return X, Y + + +def cnxcn_dataset(template): + r"""Generate a dataset for the 2D modular addition operation. + + Parameters + ---------- + template : np.ndarray + A flattened 2D square image of shape (image_length*image_length,). + + Returns + ------- + X : np.ndarray + Input data of shape (image_length^4, 2, image_length*image_length). + Y : np.ndarray + Output data of shape (image_length^4, image_length*image_length). 
+ """ + image_length = int(np.sqrt(len(template))) + X = np.zeros((image_length**4, 2, image_length * image_length)) + Y = np.zeros((image_length**4, image_length * image_length)) + translations = np.zeros((image_length**4, 3, 2), dtype=int) + + idx = 0 + template_2d = template.reshape((image_length, image_length)) + for a_x in range(image_length): + for a_y in range(image_length): + for b_x in range(image_length): + for b_y in range(image_length): + q_x = (a_x + b_x) % image_length + q_y = (a_y + b_y) % image_length + X[idx, 0, :] = np.roll(np.roll(template_2d, a_x, axis=0), a_y, axis=1).flatten() + X[idx, 1, :] = np.roll(np.roll(template_2d, b_x, axis=0), b_y, axis=1).flatten() + Y[idx, :] = np.roll(np.roll(template_2d, q_x, axis=0), q_y, axis=1).flatten() + translations[idx, 0, :] = (a_x, a_y) + translations[idx, 1, :] = (b_x, b_y) + translations[idx, 2, :] = (q_x, q_y) + idx += 1 + + return X, Y + + +def move_dataset_to_device_and_flatten(X, Y, device=None): + """Move dataset tensors to available or specified device. + + Parameters + ---------- + X : np.ndarray + Input data of shape (num_samples, 2, p*p) + Y : np.ndarray + Target data of shape (num_samples, p*p) + device : torch.device, optional + Device to move tensors to. + + Returns + ------- + X : torch.Tensor + Input data tensor on specified device, flattened to (num_samples, 2*p*p) + Y : torch.Tensor + Target data tensor on specified device, flattened to (num_samples, p*p) + """ + num_data_features = len(X[0][0]) + X_flat = X.reshape(X.shape[0], 2 * num_data_features) + Y_flat = Y.reshape(Y.shape[0], num_data_features) + X_tensor = torch.tensor(X_flat, dtype=torch.float32) + Y_tensor = torch.tensor(Y_flat, dtype=torch.float32) + + if device is None: + if torch.cuda.is_available(): + device = torch.device("cuda") + print("GPU is available. 
Using CUDA.") else: - # 0 < kx < mid_kx (or no mid): conjugate lives outside half-plane → always safe - chosen.append((ky, kx)) - - if len(chosen) < n_freqs: - raise ValueError( - f"Could only find {len(chosen)} unique non-conjugate bins; " - f"requested {n_freqs}. Increase grid size or reduce n_freqs." - ) - - # Amplitudes + random phases, then place each bin + its conjugate - amps = np.sqrt(np.linspace(amp_max, amp_min, n_freqs, dtype=float)) - phases = rng.uniform(0.0, 2 * np.pi, size=n_freqs) - - for (ky, kx), a, phi in zip(chosen, amps, phases): - kyc, kxc = (-ky) % p1, (-kx) % p2 - v = a * np.exp(1j * phi) - spectrum[ky, kx] += v - spectrum[kyc, kxc] += np.conj(v) - - template = np.fft.ifft2(spectrum).real - template -= template.mean() - s = template.std() - if s > 1e-12: - template /= s - return template.astype(np.float32) - - -def generate_fixed_template_2d(p1: int, p2: int) -> np.ndarray: - """ - Generate 2D template array from Fourier spectrum. - - Args: - p1: height dimension - p2: width dimension - - Returns: - template: (p1, p2) real-valued array - """ - # Generate template array from 2D Fourier spectrum - spectrum = np.zeros((p1, p2), dtype=complex) - - assert p1 > 5 and p2 > 5, "p1 and p2 must be greater than 5" - - # Set 2D frequency components with specific amplitudes - # Format: spectrum[kx, ky] where kx is "vertical freq", ky is "horizontal freq" - - # Axis-aligned frequencies - spectrum[1, 0] = 10.0 # vertical frequency 1 - spectrum[-1, 0] = 10.0 # conjugate - # spectrum[0, 1] = 10.0 # horizontal frequency 1 - # spectrum[0, -1] = 10.0 # conjugate - - # Higher frequency components - # spectrum[3, 0] = 7.5 - # spectrum[-3, 0] = 7.5 - spectrum[0, 3] = 7.5 - spectrum[0, -3] = 7.5 - - # Diagonal/mixed frequencies - spectrum[2, 1] = 5.0 - spectrum[-2, -1] = 5.0 # conjugate - # spectrum[1, 2] = 5.0 - # spectrum[-1, -2] = 5.0 # conjugate - - # Generate signal from spectrum - template = np.fft.ifft2(spectrum).real - - return template - - -# Spherically 
Symmetric Templates - - -def _fft_indices(n): - """ - Return integer-like frequency indices aligned with numpy's FFT layout. - Example: n=8 -> [0,1,2,3,4,-3,-2,-1] - """ - k = np.fft.fftfreq(n) * n - return k.astype(int) - - -def generate_hexagon_tie_template_2d(p1: int, p2: int, k0: float = 6.0, amp: float = 1.0): - """ - Real template whose 2D Fourier spectrum has equal maxima at six directions - (0°, 60°, 120°, 180°, 240°, 300°) with radius ~ k0 (in FFT index units). - - Args: - p1, p2: spatial dims (height, width). Require > 5 recommended. - k0: desired radius (index units). Not necessarily integer; we round. - amp: amplitude per spike (before conjugate pairing) - - Returns: - template: (p1, p2) real-valued array - """ - assert p1 > 5 and p2 > 5, "p1 and p2 must be > 5" - spec = np.zeros((p1, p2), dtype=np.complex128) - - # Six target angles for a hexagon - thetas = np.arange(6) * (np.pi / 3.0) - - # FFT index grids - Kx = _fft_indices(p1) - Ky = _fft_indices(p2) - - # Map from (kx,ky) in index space to array indices (wrapped) - def put(kx, ky, val): - spec[int(kx) % p1, int(ky) % p2] += val - - used = set() - for th in thetas: - # Target continuous coordinates at radius k0 - kx_f = k0 * np.cos(th) - ky_f = k0 * np.sin(th) - # Round to nearest integer grid point - kx = int(np.round(kx_f)) - ky = int(np.round(ky_f)) - # Avoid (0,0) and duplicates - if (kx, ky) == (0, 0): - # nudge radius by 1 if rounding hit DC - if abs(np.cos(th)) > abs(np.sin(th)): - kx = 1 if kx >= 0 else -1 - else: - ky = 1 if ky >= 0 else -1 - if (kx, ky) in used: - continue - used.add((kx, ky)) - used.add((-kx, -ky)) - - # Place equal-amplitude spikes with Hermitian symmetry - put(kx, ky, amp) # +k - put(-kx, -ky, np.conjugate(amp)) # -k (conjugate) - - # Remove DC (optional) to avoid mean offset - spec[0, 0] = 0.0 - - # Real template - x = np.fft.ifft2(spec).real - return x - - -def generate_ring_isotropic_template_2d( - p1: int, p2: int, r0: float = 6.0, sigma: float = 0.5, total_power: 
float = 1.0 -): - """ - Real template with a narrow, isotropic ring in the 2D spectrum: |X(k)| ≈ exp(- (||k||-r0)^2 / (2 sigma^2)). - This produces a spherical (circular) symmetry -> orientation tie across the ring. - - Args: - p1, p2: spatial dims - r0: target radius (index units) - sigma: radial width of the ring - total_power: scales overall energy (roughly) - - Returns: - template: (p1, p2) real-valued array - """ - assert p1 > 5 and p2 > 5, "p1 and p2 must be > 5" - - # Build index grids in FFT layout - kx = _fft_indices(p1)[:, None] # (p1,1) - ky = _fft_indices(p2)[None, :] # (1,p2) - R = np.sqrt(kx**2 + ky**2) - - # Radial Gaussian ring (real, even -> already Hermitian when phases are 0) - mag = np.exp(-0.5 * ((R - r0) / max(sigma, 1e-6)) ** 2) - - # Optional: zero DC - mag[0, 0] = 0.0 - - # Normalize to desired total power (approximate; ifft2 has 1/(p1*p2) factor) - power = np.sum(mag**2) - if power > 0: - mag *= np.sqrt(total_power / power) - - # Real, symmetric spectrum (phase = 0 everywhere) - spec = mag.astype(np.complex128) - - x = np.fft.ifft2(spec).real - return x + device = torch.device("cpu") + print("GPU is not available. Using CPU.") + X_tensor = X_tensor.to(device) + Y_tensor = Y_tensor.to(device) -def generate_gaussian_template_2d( - p1: int, - p2: int, - center: tuple[float, float] = None, - sigma: float = 2.0, - k_freqs: int = None, -) -> np.ndarray: - """ - Generate 2D template with a single Gaussian, optionally band-limited to top-k frequencies. 
- Args: - p1: height dimension - p2: width dimension - center: (cx, cy) center position, defaults to center of grid - sigma: standard deviation of Gaussian - k_freqs: if not None, keep only the top k frequencies by power (band-limit) - Returns: - template: (p1, p2) real-valued array - """ - if center is None: - center = (p1 / 2, p2 / 2) - cx, cy = center - # Create coordinate grids - x = np.arange(p1) - y = np.arange(p2) - X, Y = np.meshgrid(x, y, indexing="ij") - # Compute Gaussian - template = np.exp(-((X - cx) ** 2 + (Y - cy) ** 2) / (2 * sigma**2)) - # If k_freqs specified, band-limit to top-k frequencies - if k_freqs is not None: - # Take DFT - spectrum = np.fft.fft2(template) - # Compute power for each frequency - power = np.abs(spectrum) ** 2 - power_flat = power.flatten() - # Get indices of all frequencies - kx_indices = np.arange(p1) - ky_indices = np.arange(p2) - KX, KY = np.meshgrid(kx_indices, ky_indices, indexing="ij") - all_indices = list(zip(KX.flatten(), KY.flatten())) - # Sort by power and select top-k - sorted_idx = np.argsort(-power_flat) - top_k_idx = sorted_idx[:k_freqs] - top_k_freqs = set([all_indices[i] for i in top_k_idx]) - # Create mask: 1 for top-k frequencies, 0 for others - mask = np.zeros((p1, p2), dtype=complex) - for kx, ky in top_k_freqs: - mask[kx, ky] = 1.0 - # Apply mask and take IDFT - spectrum_masked = spectrum * mask - template = np.fft.ifft2(spectrum_masked).real - return template + return X_tensor, Y_tensor, device diff --git a/src/datasets.py b/src/datasets.py deleted file mode 100644 index a98df0d..0000000 --- a/src/datasets.py +++ /dev/null @@ -1,213 +0,0 @@ -import numpy as np -import torch - -import src.templates as templates - - -def load_dataset(config): - """Load dataset based on configuration.""" - - template = template_selector(config) - - if config["group_name"] == "cnxcn": - X, Y = cnxcn_dataset(template) - - elif config["group_name"] == "cn": - X, Y = cn_dataset(template) - - else: - X, Y = 
group_dataset(config["group"], template) - - print(f"dataset_fraction: {config['dataset_fraction']}") - - if config["dataset_fraction"] != 1.0: - assert 0 < config["dataset_fraction"] <= 1.0, "fraction must be in (0, 1]" - # Sample a subset of the dataset according to the specified fraction - N = X.shape[0] - n_sample = int(np.ceil(N * config["dataset_fraction"])) - rng = np.random.default_rng(config["seed"]) - indices = rng.choice(N, size=n_sample, replace=False) # indices of the sampled subset - X = X[indices] - Y = Y[indices] - - return X, Y, template - - -def template_selector(config): - """Select template based on configuration.""" - if config["template_type"] == "irrep_construction": - if config["group_name"] == "cnxcn": - template = templates.fixed_cnxcn_template( - config["image_length"], config["fourier_coef_diag_values"] - ) - elif config["group_name"] == "cn": - template = templates.fixed_cn_template( - config["group_n"], config["fourier_coef_diag_values"] - ) - else: - template = templates.fixed_group_template( - config["group"], config["fourier_coef_diag_values"] - ) - elif config["template_type"] == "one_hot": - template = templates.one_hot(config["group_size"]) - else: - raise ValueError(f"Unknown template type: {config['template_type']}") - return template - - -def group_dataset(group, template): - """Generate a dataset of group elements acting on the template. - - Using the regular representation. - - Parameters - ---------- - group : Group (escnn object) - The group. - template : np.ndarray, shape=[group.order()] - The template to generate the dataset from. - - Returns - ------- - X : np.ndarray, shape=[group.order()**2, 2, group.order()] - The dataset of group elements acting on the template. - Y : np.ndarray, shape=[group.order()**2, group.order()] - The dataset of group elements acting on the template. 
- """ - - # Initialize data arrays - group_order = group.order() - assert len(template) == group_order, "template must have the same length as the group order" - n_samples = group_order**2 - X = np.zeros((n_samples, 2, group_order)) - Y = np.zeros((n_samples, group_order)) - regular_rep = group.representations["regular"] - - idx = 0 - for g1 in group.elements: - for g2 in group.elements: - g1_rep = regular_rep(g1) - g2_rep = regular_rep(g2) - g12_rep = g1_rep @ g2_rep - - X[idx, 0, :] = g1_rep @ template - X[idx, 1, :] = g2_rep @ template - Y[idx, :] = g12_rep @ template - idx += 1 - - return X, Y - - -def cn_dataset(template): - """Generate a dataset for the cyclic group C_n modular addition operation.""" - group_size = len(template) - X = np.zeros((group_size * group_size, 2, group_size)) - Y = np.zeros((group_size * group_size, group_size)) - - # Generate the dataset - idx = 0 - for a in range(group_size): - for b in range(group_size): - q = (a + b) % group_size # a + b mod p - X[idx, 0, :] = np.roll(template, a) - X[idx, 1, :] = np.roll(template, b) - Y[idx, :] = np.roll(template, q) - idx += 1 - - return X, Y - - -def cnxcn_dataset(template): - r"""Generate a dataset for the 2D modular addition operation. - - General idea: We are generating a dataset where each sample consists of - two inputs (a*template and b*template) and an output (a*b)*template, - where $(a, b) \in C_n x C_n$, where $C_n x C_n$ is the product of cn - groups. The template is a flattened 2D array representing the modular addition - operation in a 2D space. - - Each element $X_i$ will contain the template with a different $a_i$, $b_i$, and - in fact $X$ contains the template at all possible $a$, $b$ shifts. - The output $Y_i$ will contain the template shifted by $a_i*b_i$. - * refers to the composition of two group actions (but by an abuse of notation, - also refers to group action on the template.) 
- - Parameters - ---------- - template : np.ndarray - A flattened 2D square image of shape (image_length*image_length,). - - Returns - ------- - X : np.ndarray - Input data of shape (image_length^4, 2, image_length*image_length). - 2 inputs (a and b), each with shape (image_length*image_length,). - is the total number of combinations of shifted a's and b's. - Y : np.ndarray - Output data of shape (image_length^4, image_length*image_length), where each - sample is the result of the modular addition. - """ - image_length = int(np.sqrt(len(template))) - # Initialize data arrays - X = np.zeros((image_length**4, 2, image_length * image_length)) - Y = np.zeros((image_length**4, image_length * image_length)) - translations = np.zeros((image_length**4, 3, 2), dtype=int) - - # Generate the dataset - idx = 0 - template_2d = template.reshape((image_length, image_length)) - for a_x in range(image_length): - for a_y in range(image_length): - for b_x in range(image_length): - for b_y in range(image_length): - q_x = (a_x + b_x) % image_length - q_y = (a_y + b_y) % image_length - X[idx, 0, :] = np.roll(np.roll(template_2d, a_x, axis=0), a_y, axis=1).flatten() - X[idx, 1, :] = np.roll(np.roll(template_2d, b_x, axis=0), b_y, axis=1).flatten() - Y[idx, :] = np.roll(np.roll(template_2d, q_x, axis=0), q_y, axis=1).flatten() - translations[idx, 0, :] = (a_x, a_y) - translations[idx, 1, :] = (b_x, b_y) - translations[idx, 2, :] = (q_x, q_y) - idx += 1 - - return X, Y - - -def move_dataset_to_device_and_flatten(X, Y, device=None): - """Move dataset tensors to available or specified device. - - Parameters - ---------- - X : np.ndarray - Input data of shape (num_samples, 2, p*p) - Y : np.ndarray - Target data of shape (num_samples, p*p) - device : torch.device, optional - Device to move tensors to. If None, automatically choose GPU if available. 
- - Returns - ------- - X : torch.Tensor - Input data tensor on specified device, flattened to (num_samples, 2*p*p) - Y : torch.Tensor - Target data tensor on specified device, flattened to (num_samples, p*p) - """ - # Reshape X to (num_samples, 2*num_data_features), where num_data_features is inferred from len(X[0][0]) - num_data_features = len(X[0][0]) - X_flat = X.reshape(X.shape[0], 2 * num_data_features) - Y_flat = Y.reshape(Y.shape[0], num_data_features) - X_tensor = torch.tensor(X_flat, dtype=torch.float32) - Y_tensor = torch.tensor(Y_flat, dtype=torch.float32) - - if device is None: - if torch.cuda.is_available(): - device = torch.device("cuda") - print("GPU is available. Using CUDA.") - else: - device = torch.device("cpu") - print("GPU is not available. Using CPU.") - - X_tensor = X_tensor.to(device) - Y_tensor = Y_tensor.to(device) - - return X_tensor, Y_tensor, device diff --git a/src/fourier.py b/src/fourier.py new file mode 100644 index 0000000..3bf199e --- /dev/null +++ b/src/fourier.py @@ -0,0 +1,63 @@ +import numpy as np +from escnn.group import * + + +def group_fourier(group, template): + """Compute the group Fourier transform of the template. + + For each irrep rho, compute the Fourier coefficient: + hat x [rho] = sum_{g in G} x[g] * rho(g).conj().T + + Parameters + ---------- + group : Group (escnn object) + The group. + template : np.ndarray, shape=[group.order()] + The template to compute the Fourier transform of. + + Returns + ------- + fourier_coefs : list of np.ndarray, each of shape=[irrep.size, irrep.size] + A list of (matrix) Fourier coefficients of template at each irrep. + """ + irreps = group.irreps() + fourier_coefs = [] + for irrep in irreps: + coef = sum([template[i_g] * irrep(g).conj().T for i_g, g in enumerate(group.elements)]) + fourier_coefs.append(coef) + return fourier_coefs + + +def group_fourier_inverse(group, fourier_coefs): + """Compute the inverse group Fourier transform. 
+ + Using the formula: + x(g) = 1/|G| * sum_{rho in irreps} dim(rho) * Tr(rho(g) * hat x[rho]) + + Parameters + ---------- + group : Group (escnn object) + The group. + fourier_coefs : list of np.ndarray, each of shape=[irrep.size, irrep.size] + The (matrix) Fourier coefficients of template at each irrep. + + Returns + ------- + signal : np.ndarray, shape=[group.order()] + The inverse Fourier transform: a signal over the group. + """ + irreps = group.irreps() + + def _inverse_at_element(g): + return ( + 1 + / group.order() + * sum( + [ + irrep.size * np.trace(irrep(g) @ fourier_coefs[i]) + for i, irrep in enumerate(irreps) + ] + ) + ) + + return np.array([_inverse_at_element(g) for g in group.elements]) diff --git a/src/group_fourier_transform.py b/src/group_fourier_transform.py deleted file mode 100644 index 9021ff6..0000000 --- a/src/group_fourier_transform.py +++ /dev/null @@ -1,138 +0,0 @@ -import numpy as np -from escnn.group import * - - -def compute_group_fourier_coef(group, template, irrep): - """Compute the Fourier coefficient of template x at irrep rho. - - hat x [rho] = sum_{g in G} x[g] * rho(g).conj().T - - Formula from the Group-AGF paper. - - Parameters - ---------- - group : Group - The group (escnn object) - template : np.ndarray, shape=[group.order()] - The template to compute the Fourier coefficient of. - irrep : IrreducibleRepresentation - The irrep (escnn object). - - Returns - ------- - _ : np.ndarray, shape=[irrep.size, irrep.size] - The (matrix) Fourier coefficient of template x at irrep rho. - """ - return sum([template[i_g] * irrep(g).conj().T for i_g, g in enumerate(group.elements)]) - - -def compute_group_fourier_transform(group, template): - """Compute the group Fourier transform of the template. - - Parameters - ---------- - group : Group - The group (escnn object) - template : np.ndarray, shape=[group.order()] - The template to compute the Fourier transform of. 
- - Returns - ------- - _: list of np.ndarray, each of shape=[irrep.size, irrep.size] - A list of (matrix) Fourier coefficients of template at each irrep. - Since each irrep has a different size (dimension), the (matrix) Fourier - coefficients have different shapes: the list cannot be concatenated - into a single array. - """ - irreps = group.irreps() - fourier_coefs = [] - for irrep in irreps: - fourier_coef = compute_group_fourier_coef(group, template, irrep) - fourier_coefs.append(fourier_coef) - return fourier_coefs - - -def compute_group_inverse_fourier_element(group, fourier_transform, g): - """Compute the inverse Fourier transform at element g. - - Using the formula (Wikipedia): - x(g) = 1/|G| * sum_{rho in irreps} dim(rho) * Tr(rho(g) * hat x[rho]) - - Parameters - ---------- - group : Group (escnn object) - The group. - fourier_transform : list of np.ndarray, each of shape=[irrep.size, irrep.size] - The (matrix) Fourier coefficients of template at each irrep. - g : GroupElement (escnn object) - The element of the group to compute the inverse Fourier transform at. - - Returns - ------- - _ : np.ndarray, shape=[group.order()] - The inverse Fourier transform at element g. - """ - irreps = group.irreps() - - inverse_fourier_element = ( - 1 - / group.order() - * sum( - [ - irrep.size * np.trace(irrep(g) @ fourier_transform[i]) - for i, irrep in enumerate(irreps) - ] - ) - ) - - return inverse_fourier_element - - -def compute_group_inverse_fourier_transform(group, fourier_coefs): - """Compute the inverse Fourier transform. - - Parameters - ---------- - group : Group (escnn object) - The group. - fourier_coefs : list of np.ndarray, each of shape=[irrep.size, irrep.size] - The (matrix) Fourier coefficients of template at each irrep. - - Returns - ------- - _ : np.ndarray, shape=[group.order()] - The inverse Fourier transform: a signal over the group. 
- """ - return np.array( - [compute_group_inverse_fourier_element(group, fourier_coefs, g) for g in group.elements] - ) - - -def group_power_spectrum(group, template): - """Compute the (group) power spectrum of the template. - - For each irrep rho, the power is given by: - ||hat x(rho)||_rho = dim(rho) * Tr(hat x(rho)^dagger * hat x(rho)) - where hat x(rho) is the (matrix) Fourier coefficient of template x at irrep rho. - - Parameters - ---------- - group : Group (escnn object) - The group. - template : np.ndarray, shape=[group.order()] - The template to compute the power spectrum of. - - Returns - ------- - _ : np.ndarray, shape=[len(group.irreps())] - The power spectrum of the template. - """ - - irreps = group.irreps() - - power_spectrum = np.zeros(len(irreps)) - for i, irrep in enumerate(irreps): - fourier_coef = compute_group_fourier_coef(group, template, irrep) - power_spectrum[i] = irrep.size * np.trace(fourier_coef.conj().T @ fourier_coef) - power_spectrum = power_spectrum / group.order() - return np.array(power_spectrum) diff --git a/src/main.py b/src/main.py index 0d9aaa3..a03a75f 100644 --- a/src/main.py +++ b/src/main.py @@ -13,25 +13,13 @@ from torch import nn, optim from torch.utils.data import DataLoader -from src.datamodule import ( - generate_fourier_template_1d, - generate_gaussian_template_1d, - generate_onehot_template_1d, - generate_template_unique_freqs, - mnist_template_1d, - mnist_template_2d, -) -from src.model import QuadraticRNN, SequentialMLP, TwoLayerNet -from src.optimizers import HybridRNNOptimizer, PerNeuronScaledSGD -from src.utils import ( - plot_2d_signal, - plot_model_predictions_over_time, - plot_model_predictions_over_time_1d, - plot_prediction_power_spectrum_over_time_1d, - plot_training_loss_with_theory, - plot_wmix_frequency_structure, - topk_template_freqs, -) +import src.dataset as dataset +import src.fourier as fourier +import src.model as model +import src.optimizer as optimizer +import src.power as power +import 
src.template as template +import src.viz as viz matplotlib.rcParams["pdf.fonttype"] = 42 # TrueType fonts for PDF viewer compatibility matplotlib.rcParams["ps.fonttype"] = 42 @@ -172,9 +160,7 @@ def produce_plots_2d( ### ----- GENERATE EVALUATION DATA ----- ### print("Generating evaluation data for visualization...") - from src.datamodule import build_modular_addition_sequence_dataset_2d - - X_seq_2d, Y_seq_2d, _ = build_modular_addition_sequence_dataset_2d( + X_seq_2d, Y_seq_2d, _ = dataset.build_modular_addition_sequence_dataset_2d( config["data"]["p1"], config["data"]["p2"], template_2d, @@ -201,7 +187,7 @@ def produce_plots_2d( print("\nPlotting training loss...") # Plot 1: Loss vs Steps/Epochs - plot_training_loss_with_theory( + viz.plot_train_loss_with_theory( loss_history=train_loss_hist, template_2d=template_2d, p1=config["data"]["p1"], @@ -213,7 +199,7 @@ def produce_plots_2d( ) # Plot 2: Loss vs Samples Seen - plot_training_loss_with_theory( + viz.plot_train_loss_with_theory( loss_history=train_loss_hist, template_2d=template_2d, p1=config["data"]["p1"], @@ -225,7 +211,7 @@ def produce_plots_2d( ) # Plot 3: Loss vs Fraction of Space - plot_training_loss_with_theory( + viz.plot_train_loss_with_theory( loss_history=train_loss_hist, template_2d=template_2d, p1=config["data"]["p1"], @@ -238,7 +224,7 @@ def produce_plots_2d( ### ----- PLOT MODEL PREDICTIONS ----- ### print("Plotting model predictions over time...") - plot_model_predictions_over_time( + viz.plot_predictions_2d( model, param_hist, X_seq_2d_t, @@ -252,14 +238,14 @@ def produce_plots_2d( ### ----- PLOT FOURIER MODES REFERENCE ----- ### print("Creating Fourier modes reference...") - tracked_freqs = topk_template_freqs(template_2d, K=10) + tracked_freqs = power.topk_template_freqs(template_2d, K=10) colors = plt.cm.tab10(np.linspace(0, 1, len(tracked_freqs))) ### ----- PLOT W_MIX FREQUENCY STRUCTURE (QuadraticRNN only) ----- ### model_type = config["model"]["model_type"] if model_type == 
"QuadraticRNN": print("Visualizing W_mix frequency structure...") - plot_wmix_frequency_structure( + viz.plot_wmix_structure( param_hist, tracked_freqs, colors, @@ -331,9 +317,7 @@ def produce_plots_1d( ### ----- GENERATE EVALUATION DATA ----- ### print("Generating evaluation data for visualization...") - from src.datamodule import build_modular_addition_sequence_dataset_1d - - X_seq_1d, Y_seq_1d, _ = build_modular_addition_sequence_dataset_1d( + X_seq_1d, Y_seq_1d, _ = dataset.build_modular_addition_sequence_dataset_1d( config["data"]["p"], template_1d, config["data"]["k"], @@ -386,7 +370,7 @@ def produce_plots_1d( ### ----- PLOT MODEL PREDICTIONS ----- ### print("Plotting model predictions over time...") - plot_model_predictions_over_time_1d( + viz.plot_predictions_1d( model, param_hist, X_seq_1d_t, @@ -399,7 +383,7 @@ def produce_plots_1d( ### ----- PLOT POWER SPECTRUM ANALYSIS ----- ### print("Analyzing power spectrum of predictions over training...") - plot_prediction_power_spectrum_over_time_1d( + viz.plot_power_1d( model, param_hist, X_seq_1d_t, @@ -418,224 +402,6 @@ def produce_plots_1d( print("\n✓ All 1D plots generated successfully!") -def plot_model_predictions_over_time_group( - model, - param_hist, - X_eval, - Y_eval, - group_order: int, - checkpoint_indices: list, - save_path: str = None, - num_samples: int = 5, - group_label: str = "Group", -): - """ - Plot model predictions vs targets at different training checkpoints. 
- - Args: - model: Trained model - param_hist: List of parameter snapshots - X_eval: Input evaluation tensor (N, k, group_order) - Y_eval: Target evaluation tensor (N, group_order) - group_order: Order of the group - checkpoint_indices: Indices into param_hist to visualize - save_path: Path to save the plot - num_samples: Number of samples to show - group_label: Human-readable label for the group (used in plot title) - """ - n_checkpoints = len(checkpoint_indices) - - fig, axes = plt.subplots( - num_samples, n_checkpoints, figsize=(4 * n_checkpoints, 3 * num_samples) - ) - if num_samples == 1: - axes = axes.reshape(1, -1) - if n_checkpoints == 1: - axes = axes.reshape(-1, 1) - - # Select random sample indices - sample_indices = np.random.choice( - len(X_eval), size=min(num_samples, len(X_eval)), replace=False - ) - - for col, ckpt_idx in enumerate(checkpoint_indices): - # Load parameters for this checkpoint - model.load_state_dict(param_hist[ckpt_idx]) - model.eval() - - with torch.no_grad(): - outputs = model(X_eval[sample_indices]) - outputs_np = outputs.cpu().numpy() - targets_np = Y_eval[sample_indices].cpu().numpy() - - for row, (output, target) in enumerate(zip(outputs_np, targets_np)): - ax = axes[row, col] - x_axis = np.arange(group_order) - - ax.bar(x_axis - 0.15, target, width=0.3, label="Target", alpha=0.7, color="#2ecc71") - ax.bar(x_axis + 0.15, output, width=0.3, label="Output", alpha=0.7, color="#e74c3c") - - if row == 0: - ax.set_title(f"Checkpoint {ckpt_idx}") - if col == 0: - ax.set_ylabel(f"Sample {sample_indices[row]}") - if row == num_samples - 1: - ax.set_xlabel("Group element") - if row == 0 and col == n_checkpoints - 1: - ax.legend(loc="upper right", fontsize=8) - - ax.set_xticks(x_axis) - ax.grid(True, alpha=0.3) - - plt.suptitle(f"{group_label} Predictions vs Targets Over Training", fontsize=14) - plt.tight_layout() - - if save_path: - plt.savefig(save_path, bbox_inches="tight", dpi=150) - plt.close() - - -def 
plot_power_spectrum_over_time_group( - model, - param_hist, - param_save_indices, - X_eval, - template: np.ndarray, - group, - k: int, - optimizer: str, - init_scale: float, - save_path: str = None, - group_label: str = "Group", -): - """ - Plot power spectrum of model outputs vs template power spectrum over training. - - Uses GroupPower from src/power.py for template power and model_power_over_time - for model output power over training checkpoints. - - Args: - model: Trained model - param_hist: List of parameter snapshots - param_save_indices: List mapping param_hist index to epoch number - X_eval: Input evaluation tensor - template: Template array (group_order,) - group: escnn group object - k: Sequence length - optimizer: Optimizer name (e.g., 'per_neuron', 'adam') - init_scale: Initialization scale - save_path: Path to save the plot - group_label: Human-readable label for the group (used in plot titles) - """ - from src.power import GroupPower, model_power_over_time - - group_name = "group" # generic group for model_power_over_time dispatch - irreps = group.irreps() - n_irreps = len(irreps) - - # Compute template power spectrum using GroupPower - template_power_obj = GroupPower(template, group=group) - template_power = template_power_obj.power - - print(f" Template power spectrum: {template_power}") - print(" (These are dim^2 * diag_value^2 / |G| for each irrep)") - - # Compute model output power over training using model_power_over_time - model_powers, steps = model_power_over_time(group_name, model, param_hist, X_eval, group=group) - # Map step indices to epoch numbers - epoch_numbers = [param_save_indices[min(s, len(param_save_indices) - 1)] for s in steps] - - # Create 3 subplots: linear, log-x, log-log - fig, axes = plt.subplots(1, 3, figsize=(18, 5)) - - top_k = min(5, n_irreps) - top_irrep_indices = np.argsort(template_power)[::-1][:top_k] - - colors_line = plt.cm.tab10(np.linspace(0, 1, top_k)) - - # Filter out zero epochs for log scales - valid_mask = 
np.array(epoch_numbers) > 0 - valid_epochs = np.array(epoch_numbers)[valid_mask] - valid_model_powers = model_powers[valid_mask, :] - - # Plot 1: Linear scales - ax = axes[0] - for i, irrep_idx in enumerate(top_irrep_indices): - power_values = model_powers[:, irrep_idx] - ax.plot( - epoch_numbers, - power_values, - "-", - lw=2, - color=colors_line[i], - label=f"Irrep {irrep_idx} (dim={irreps[irrep_idx].size})", - ) - ax.axhline(template_power[irrep_idx], linestyle="--", alpha=0.5, color=colors_line[i]) - ax.set_xlabel("Epoch") - ax.set_ylabel("Power") - ax.set_title("Linear Scales", fontsize=12) - ax.legend(loc="upper left", fontsize=7) - ax.grid(True, alpha=0.3) - - # Plot 2: Log x-axis only - ax = axes[1] - for i, irrep_idx in enumerate(top_irrep_indices): - power_values = valid_model_powers[:, irrep_idx] - ax.plot( - valid_epochs, - power_values, - "-", - lw=2, - color=colors_line[i], - label=f"Irrep {irrep_idx} (dim={irreps[irrep_idx].size})", - ) - ax.axhline(template_power[irrep_idx], linestyle="--", alpha=0.5, color=colors_line[i]) - ax.set_xscale("log") - ax.set_xlabel("Epoch (log scale)") - ax.set_ylabel("Power") - ax.set_title("Log X-axis", fontsize=12) - ax.legend(loc="upper left", fontsize=7) - ax.grid(True, alpha=0.3) - - # Plot 3: Log-log scales - ax = axes[2] - for i, irrep_idx in enumerate(top_irrep_indices): - power_values = valid_model_powers[:, irrep_idx] - # Filter out zero powers for log scale - power_mask = power_values > 0 - if np.any(power_mask): - ax.plot( - valid_epochs[power_mask], - power_values[power_mask], - "-", - lw=2, - color=colors_line[i], - label=f"Irrep {irrep_idx} (dim={irreps[irrep_idx].size})", - ) - if template_power[irrep_idx] > 0: - ax.axhline(template_power[irrep_idx], linestyle="--", alpha=0.5, color=colors_line[i]) - ax.set_xscale("log") - ax.set_yscale("log") - ax.set_xlabel("Epoch (log scale)") - ax.set_ylabel("Power (log scale)") - ax.set_title("Log-Log Scales", fontsize=12) - ax.legend(loc="upper left", fontsize=7) 
- ax.grid(True, alpha=0.3) - - # Overall title - fig.suptitle( - f"{group_label} Power Evolution Over Training (k={k}, {optimizer}, init={init_scale:.0e})", - fontsize=14, - fontweight="bold", - ) - - plt.tight_layout() - - if save_path: - plt.savefig(save_path, bbox_inches="tight", dpi=150) - plt.close() - - def produce_plots_group( run_dir: Path, config: dict, @@ -719,10 +485,10 @@ def produce_plots_group( if model_type == "TwoLayerNet": # TwoLayerNet expects flattened binary pair input: (N, 2*group_size) - from src.datasets import group_dataset, move_dataset_to_device_and_flatten - - X_raw, Y_raw = group_dataset(group, template) - X_eval_t, Y_eval_t, device = move_dataset_to_device_and_flatten(X_raw, Y_raw, device=device) + X_raw, Y_raw = dataset.group_dataset(group, template) + X_eval_t, Y_eval_t, device = dataset.move_dataset_to_device_and_flatten( + X_raw, Y_raw, device=device + ) # Optionally subsample for visualization n_eval = min(len(X_eval_t), 1000) if n_eval < len(X_eval_t): @@ -731,9 +497,7 @@ def produce_plots_group( Y_eval_t = Y_eval_t[indices] else: # Sequence models use the generic sequence dataset - from src.datamodule import build_modular_addition_sequence_dataset_generic - - X_eval, Y_eval, _ = build_modular_addition_sequence_dataset_generic( + X_eval, Y_eval, _ = dataset.build_modular_addition_sequence_dataset_generic( template, k, group=group, @@ -782,7 +546,7 @@ def produce_plots_group( ### ----- PLOT MODEL PREDICTIONS OVER TIME ----- ### print("\nPlotting model predictions over time...") - plot_model_predictions_over_time_group( + viz.plot_predictions_group( model=model, param_hist=param_hist, X_eval=X_eval_t, @@ -796,9 +560,9 @@ def produce_plots_group( ### ----- PLOT POWER SPECTRUM OVER TIME ----- ### print("\nPlotting power spectrum over time...") - optimizer = config["training"]["optimizer"] + optimizer_name = config["training"]["optimizer"] init_scale = config["model"]["init_scale"] - plot_power_spectrum_over_time_group( + 
viz.plot_power_group( model=model, param_hist=param_hist, param_save_indices=param_save_indices, @@ -806,7 +570,7 @@ def produce_plots_group( template=template, group=group, k=k, - optimizer=optimizer, + optimizer=optimizer_name, init_scale=init_scale, save_path=os.path.join(run_dir, "power_spectrum_analysis.pdf"), group_label=group_label, @@ -854,23 +618,19 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: p_flat = p if template_type == "mnist": - template_1d = mnist_template_1d(p, config["data"]["mnist_label"], root="data") + template_1d = template.mnist_1d(p, config["data"]["mnist_label"], root="data") elif template_type == "fourier": n_freqs = config["data"]["n_freqs"] - template_1d = generate_fourier_template_1d( - p, n_freqs=n_freqs, seed=config["data"]["seed"] - ) + template_1d = template.fourier_1d(p, n_freqs=n_freqs, seed=config["data"]["seed"]) elif template_type == "gaussian": - template_1d = generate_gaussian_template_1d( - p, n_gaussians=3, seed=config["data"]["seed"] - ) + template_1d = template.gaussian_1d(p, n_gaussians=3, seed=config["data"]["seed"]) elif template_type == "onehot": - template_1d = generate_onehot_template_1d(p) + template_1d = template.onehot_1d(p) else: raise ValueError(f"Unknown template_type: {template_type}") template_1d = template_1d - np.mean(template_1d) - template = template_1d # For consistency in code below + tpl = template_1d # For consistency in code below # Visualize 1D template print("Visualizing template...") @@ -890,28 +650,24 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: p_flat = p1 * p2 if template_type == "mnist": - template_2d = mnist_template_2d(p1, p2, config["data"]["mnist_label"], root="data") + template_2d = template.mnist_2d(p1, p2, config["data"]["mnist_label"], root="data") elif template_type == "fourier": n_freqs = config["data"]["n_freqs"] - template_2d = generate_template_unique_freqs( + template_2d = template.unique_freqs_2d( p1, p2, n_freqs=n_freqs, 
seed=config["data"]["seed"] ) else: raise ValueError(f"Unknown template_type for cnxcn: {template_type}") template_2d = template_2d - np.mean(template_2d) - template = template_2d # For consistency in code below + tpl = template_2d # For consistency in code below # Visualize 2D template print("Visualizing template...") - fig, ax = plot_2d_signal(template_2d, title="Template", cmap="gray") + fig, ax = viz.plot_signal_2d(template_2d, title="Template", cmap="gray") fig.savefig(os.path.join(run_dir, "template.pdf"), bbox_inches="tight", dpi=150) print(" ✓ Saved template") elif group_name in ("dihedral", "octahedral", "A5"): - from src.group_fourier_transform import ( - compute_group_inverse_fourier_transform, - ) - # Construct the escnn group object if group_name == "dihedral": from escnn.group import DihedralGroup @@ -938,9 +694,9 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: # Generate template if template_type == "onehot": - template = np.zeros(group_order, dtype=np.float32) - template[1] = 10.0 - template = template - np.mean(template) + tpl = np.zeros(group_order, dtype=np.float32) + tpl[1] = 10.0 + tpl = tpl - np.mean(tpl) print("Template type: onehot") elif template_type == "custom_fourier": @@ -972,21 +728,21 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: ) spectrum.append(mat) - template = compute_group_inverse_fourier_transform(group, spectrum) - template = template - np.mean(template) - template = template.astype(np.float32) + tpl = fourier.group_fourier_inverse(group, spectrum) + tpl = tpl - np.mean(tpl) + tpl = tpl.astype(np.float32) else: raise ValueError( f"Unknown template_type for {group_name}: {template_type}. 
" "Must be 'onehot' or 'custom_fourier'" ) - print(f"Template shape: {template.shape}") + print(f"Template shape: {tpl.shape}") # Visualize template print("Visualizing template...") fig, ax = plt.subplots(figsize=(max(8, group_order // 5), 4)) - ax.bar(range(group_order), template) + ax.bar(range(group_order), tpl) ax.set_xlabel("Group element index") ax.set_ylabel("Value") title = f"{group_label} Template (order={group_order}, type={template_type})" @@ -1007,14 +763,14 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: print("Setting up model and training...") # Flatten template for model (works for both 1D and 2D) - template_torch = torch.tensor(template, device=device, dtype=torch.float32).flatten() + template_torch = torch.tensor(tpl, device=device, dtype=torch.float32).flatten() # Determine which model to use model_type = config["model"]["model_type"] print(f"Using model type: {model_type}") if model_type == "QuadraticRNN": - rnn_2d = QuadraticRNN( + rnn_2d = model.QuadraticRNN( p=p_flat, d=config["model"]["hidden_dim"], template=template_torch, @@ -1023,7 +779,7 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: transform_type=config["model"]["transform_type"], ).to(device) elif model_type == "SequentialMLP": - rnn_2d = SequentialMLP( + rnn_2d = model.SequentialMLP( p=p_flat, d=config["model"]["hidden_dim"], template=template_torch, @@ -1035,7 +791,7 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: hidden_dim = config["model"]["hidden_dim"] nonlinearity = config["model"].get("nonlinearity", "square") output_scale = config["model"].get("output_scale", 1.0) - rnn_2d = TwoLayerNet( + rnn_2d = model.TwoLayerNet( group_size=p_flat, hidden_size=hidden_dim, nonlinearity=nonlinearity, @@ -1064,7 +820,7 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: print(f"Using optimizer: {optimizer_name}") if optimizer_name == "adam": - optimizer = optim.Adam( + opt = optim.Adam( rnn_2d.parameters(), 
lr=config["training"]["learning_rate"], betas=tuple(config["training"]["betas"]), @@ -1075,7 +831,7 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: raise ValueError( f"'hybrid' optimizer is only supported for QuadraticRNN, got {model_type}" ) - optimizer = HybridRNNOptimizer( + opt = optimizer.HybridRNNOptimizer( rnn_2d, lr=1, scaling_factor=config["training"]["scaling_factor"], @@ -1093,12 +849,12 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: print(" Note: Using lr=1.0 for per_neuron optimizer with SequentialMLP") lr = 1.0 - optimizer = PerNeuronScaledSGD( + opt = optimizer.PerNeuronScaledSGD( rnn_2d, lr=lr, degree=degree, # Will auto-infer as k+1 for SequentialMLP (k = sequence length) ) - print(f" Degree of homogeneity: {optimizer.param_groups[0]['degree']}") + print(f" Degree of homogeneity: {opt.param_groups[0]['degree']}") else: raise ValueError( f"Invalid optimizer: {optimizer_name}. Must be 'adam', 'hybrid', or 'per_neuron'" @@ -1111,10 +867,8 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: print("Using ONLINE data generation...") if group_name == "cn": - from src.datamodule import OnlineModularAdditionDataset1D - # Training dataset - train_dataset = OnlineModularAdditionDataset1D( + train_dataset = dataset.OnlineModularAdditionDataset1D( p=config["data"]["p"], template=template_1d, k=config["data"]["k"], @@ -1124,7 +878,7 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: ) # Validation dataset - val_dataset = OnlineModularAdditionDataset1D( + val_dataset = dataset.OnlineModularAdditionDataset1D( p=config["data"]["p"], template=template_1d, k=config["data"]["k"], @@ -1133,10 +887,8 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: return_all_outputs=config["model"]["return_all_outputs"], ) elif group_name == "cnxcn": - from src.datamodule import OnlineModularAdditionDataset2D - # Training dataset - train_dataset = OnlineModularAdditionDataset2D( + 
train_dataset = dataset.OnlineModularAdditionDataset2D( p1=config["data"]["p1"], p2=config["data"]["p2"], template=template_2d, @@ -1147,7 +899,7 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: ) # Validation dataset - val_dataset = OnlineModularAdditionDataset2D( + val_dataset = dataset.OnlineModularAdditionDataset2D( p1=config["data"]["p1"], p2=config["data"]["p2"], template=template_2d, @@ -1178,26 +930,21 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: from torch.utils.data import TensorDataset if model_type == "TwoLayerNet": - # TwoLayerNet uses binary pair datasets from src/datasets.py + # TwoLayerNet uses binary pair datasets from src/dataset.py # Data shape: X=(N, 2, group_size) -> flattened to (N, 2*group_size), Y=(N, group_size) - from src.datasets import ( - cn_dataset, - cnxcn_dataset, - group_dataset, - move_dataset_to_device_and_flatten, - ) - if group_name == "cn": - X_raw, Y_raw = cn_dataset(template) + X_raw, Y_raw = dataset.cn_dataset(tpl) elif group_name == "cnxcn": - X_raw, Y_raw = cnxcn_dataset(template) + X_raw, Y_raw = dataset.cnxcn_dataset(tpl) elif group_name in ("dihedral", "octahedral", "A5"): - X_raw, Y_raw = group_dataset(group, template) + X_raw, Y_raw = dataset.group_dataset(group, tpl) else: raise ValueError(f"Unsupported group_name for TwoLayerNet: {group_name}") # Flatten X from (N, 2, group_size) to (N, 2*group_size) and convert to tensors - X_all, Y_all, device = move_dataset_to_device_and_flatten(X_raw, Y_raw, device=device) + X_all, Y_all, device = dataset.move_dataset_to_device_and_flatten( + X_raw, Y_raw, device=device + ) # Apply dataset_fraction if configured dataset_fraction = config["data"].get("dataset_fraction", 1.0) @@ -1218,10 +965,8 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: else: # Sequence models (QuadraticRNN, SequentialMLP) use sequence datasets if group_name == "cn": - from src.datamodule import build_modular_addition_sequence_dataset_1d - # 
Generate training dataset - X_train, Y_train, _ = build_modular_addition_sequence_dataset_1d( + X_train, Y_train, _ = dataset.build_modular_addition_sequence_dataset_1d( config["data"]["p"], template_1d, config["data"]["k"], @@ -1232,7 +977,7 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: # Generate validation dataset val_samples = max(1000, config["data"]["num_samples"] // 10) - X_val, Y_val, _ = build_modular_addition_sequence_dataset_1d( + X_val, Y_val, _ = dataset.build_modular_addition_sequence_dataset_1d( config["data"]["p"], template_1d, config["data"]["k"], @@ -1241,10 +986,8 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: return_all_outputs=config["model"]["return_all_outputs"], ) elif group_name == "cnxcn": - from src.datamodule import build_modular_addition_sequence_dataset_2d - # Generate training dataset - X_train, Y_train, _ = build_modular_addition_sequence_dataset_2d( + X_train, Y_train, _ = dataset.build_modular_addition_sequence_dataset_2d( config["data"]["p1"], config["data"]["p2"], template_2d, @@ -1256,7 +999,7 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: # Generate validation dataset val_samples = max(1000, config["data"]["num_samples"] // 10) - X_val, Y_val, _ = build_modular_addition_sequence_dataset_2d( + X_val, Y_val, _ = dataset.build_modular_addition_sequence_dataset_2d( config["data"]["p1"], config["data"]["p2"], template_2d, @@ -1266,10 +1009,8 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: return_all_outputs=config["model"]["return_all_outputs"], ) elif group_name in ("dihedral", "octahedral", "A5"): - from src.datamodule import build_modular_addition_sequence_dataset_generic - - X_train, Y_train, _ = build_modular_addition_sequence_dataset_generic( - template, + X_train, Y_train, _ = dataset.build_modular_addition_sequence_dataset_generic( + tpl, config["data"]["k"], group=group, mode=config["data"]["mode"], @@ -1278,8 +1019,8 @@ def 
train_single_run(config: dict, run_dir: Path = None) -> dict: ) val_samples = max(1000, config["data"]["num_samples"] // 10) - X_val, Y_val, _ = build_modular_addition_sequence_dataset_generic( - template, + X_val, Y_val, _ = dataset.build_modular_addition_sequence_dataset_generic( + tpl, config["data"]["k"], group=group, mode="sampled", @@ -1323,34 +1064,38 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: start_time = time.time() if training_mode == "online": - from src.train import train_online - - train_loss_hist, val_loss_hist, param_hist, param_save_indices, final_step = train_online( - rnn_2d, - train_loader, - criterion, - optimizer, - num_steps=num_steps, - verbose_interval=config["training"]["verbose_interval"], - grad_clip=config["training"]["grad_clip"], - eval_dataloader=val_loader, - save_param_interval=config["training"]["save_param_interval"], - reduction_threshold=reduction_threshold, + from src import train as train_mod + + train_loss_hist, val_loss_hist, param_hist, param_save_indices, final_step = ( + train_mod.train_online( + rnn_2d, + train_loader, + criterion, + opt, + num_steps=num_steps, + verbose_interval=config["training"]["verbose_interval"], + grad_clip=config["training"]["grad_clip"], + eval_dataloader=val_loader, + save_param_interval=config["training"]["save_param_interval"], + reduction_threshold=reduction_threshold, + ) ) else: # offline - from src.train import train - - train_loss_hist, val_loss_hist, param_hist, param_save_indices, final_step = train( - rnn_2d, - train_loader, - criterion, - optimizer, - epochs=epochs, - verbose_interval=config["training"]["verbose_interval"], - grad_clip=config["training"]["grad_clip"], - eval_dataloader=val_loader, - save_param_interval=config["training"]["save_param_interval"], - reduction_threshold=reduction_threshold, + from src import train as train_mod + + train_loss_hist, val_loss_hist, param_hist, param_save_indices, final_step = ( + train_mod.train( + rnn_2d, + 
train_loader, + criterion, + opt, + epochs=epochs, + verbose_interval=config["training"]["verbose_interval"], + grad_clip=config["training"]["grad_clip"], + eval_dataloader=val_loader, + save_param_interval=config["training"]["save_param_interval"], + reduction_threshold=reduction_threshold, + ) ) training_time = time.time() - start_time @@ -1373,7 +1118,7 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: train_loss_hist, val_loss_hist, param_hist, - template, + tpl, training_time, device, ) @@ -1413,7 +1158,7 @@ def train_single_run(config: dict, run_dir: Path = None) -> dict: param_hist=param_hist, param_save_indices=param_save_indices, train_loss_hist=train_loss_hist, - template=template, + template=tpl, device=device, group=group, ) diff --git a/src/optimizers.py b/src/optimizer.py similarity index 100% rename from src/optimizers.py rename to src/optimizer.py diff --git a/src/plot.py b/src/plot.py deleted file mode 100644 index 96d9705..0000000 --- a/src/plot.py +++ /dev/null @@ -1,546 +0,0 @@ -import collections - -import matplotlib.pyplot as plt -import numpy as np -import torch - -import src.power as power - -FONT_SIZES = {"title": 30, "axes_label": 30, "tick_label": 30, "legend": 15} - - -def plot_loss_curve(loss_history, template_power, save_path=None, show=False, freq_colors=None): - """Plot loss curve over epochs. - - Parameters - ---------- - loss_history : list of float - List of loss values recorded at each epoch. - template_power : class instance - Used to calculate theoretical plateau lines for AGF. - save_path : str, optional - Path to save the plot. If None, the plot is not saved. - show : bool, optional - Whether to display the plot. - freq_colors : list of str, optional - List of colors (in format "C0, C1, etc.") to use for different frequency intervals. - If None, a single color is used for the entire loss curve. 
- """ - fig = plt.figure(figsize=(6, 6)) - - plateau_predictions = template_power.loss_plateau_predictions() - print(f"Plotting {len(plateau_predictions)} theoretical plateau lines.") - - for k, alpha in enumerate(plateau_predictions): - print(f"Plotting alpha value {k}: {alpha}") - plt.axhline(y=alpha, color="black", linestyle="--", linewidth=2, zorder=-2) - - if freq_colors is None: - plt.plot(list(loss_history), lw=4) - else: - plateau_predictions = np.array(plateau_predictions) - num_alpha_intervals = len(plateau_predictions) - 1 - grouped_epochs = [[] for _ in range(num_alpha_intervals + 1)] - grouped_losses = [[] for _ in range(num_alpha_intervals + 1)] - - for epoch, loss in enumerate(loss_history): - in_interval = False - for ai in range(num_alpha_intervals): - if (loss <= plateau_predictions[ai] + 1e-1) and ( - loss > plateau_predictions[ai + 1] + 1e-1 - ): - grouped_epochs[ai].append(epoch) - grouped_losses[ai].append(loss) - in_interval = True - break - # Handle losses <= to the smallest (last) alpha value - include them in last group - if not in_interval and (loss <= plateau_predictions[-1] + 1e-1): - grouped_epochs[-1].append(epoch) - grouped_losses[-1].append(loss) - - print(f"Freq colors: {freq_colors}, number of alpha intervals: {num_alpha_intervals}") - for ai in range(num_alpha_intervals + 1): - color = freq_colors[ai] if ai < len(freq_colors) else freq_colors[-1] - if ai < num_alpha_intervals: - print( - f"Color for alpha value {ai} (alpha={plateau_predictions[ai]}): {color}, number of points: {len(grouped_epochs[ai])}" - ) - else: - print( - f"Color for alpha values < {plateau_predictions[-1]}: {color}, number of points: {len(grouped_epochs[ai])}" - ) - if grouped_epochs[ai]: # only plot if group is non-empty - plt.plot(grouped_epochs[ai], grouped_losses[ai], color=color, lw=4) - - plt.xscale("log") - plt.yscale("log") - - ymin, ymax = plt.ylim() - yticks = np.linspace(ymin, ymax, num=6) - yticklabels = [f"{t:.1e}" for t in yticks] - plt.yticks( 
- yticks, - yticklabels, - fontsize=FONT_SIZES["ticks"] if "ticks" in FONT_SIZES else 18, - ) - - tick_locs = [v for v in [100, 1000, 10000, 100000] if v < len(loss_history) - 1] - tick_labels = [rf"$10^{{{int(np.log10(loc))}}}$" for loc in tick_locs] - plt.xticks( - tick_locs, - tick_labels, - fontsize=FONT_SIZES["ticks"] if "ticks" in FONT_SIZES else 18, - ) - - plt.xlabel("Epochs", fontsize=FONT_SIZES["axes_label"]) - plt.ylabel("Train Loss", fontsize=FONT_SIZES["axes_label"]) - - # Cut off y-axis slightly below the lowest alpha value for higher resolution - y_min = plateau_predictions[-1] * 0.7 if plateau_predictions[-1] > 0 else 1e-8 - plt.ylim(bottom=y_min) - plt.xlim(0, len(loss_history) + 100) - - plt.grid(False) - plt.tight_layout() - - if save_path is not None: - plt.savefig(save_path, bbox_inches="tight") - if show: - plt.show() - return fig - - -def plot_training_power_over_time( - template_power_object, - model, - device, - param_history, - X_tensor, - group_name, - save_path=None, - logscale=False, - show=False, - return_freq_colors=False, -): - """Plot the power spectrum of the model's learned weights over time. - - Parameters - ---------- - template_power_object : class instance - Instance of <>Power containing the template power spectrum. - model : nn.Module - The trained model. - device : torch.device - Device to run computations on. - param_history : list of dict - List of parameter snapshots (as yielded by model.state_dict()) during training. - X_tensor : torch.Tensor - Input data tensor of shape (num_samples, ...). - group_name : str - Name of the group (should distinguish 'cnxcn'). - save_path : str, optional - Path to save the plot. If None, the plot is not saved. - logscale : bool, optional - Whether to use logarithmic scale for y-axis. - show : bool, optional - Whether to display the plot. - return_freq_colors : bool, optional - Whether to return the frequency colors used in the plot - (to optionally coordinate with loss curve). 
- """ - if group_name == "cnxcn": - escnn_group = None - row_freqs, column_freqs = ( - template_power_object.x_freqs, - template_power_object.y_freqs, - ) - freq = np.array( - [(row_freq, column_freq) for row_freq in row_freqs for column_freq in column_freqs] - ) - elif group_name == "cn": - escnn_group = None - freq = template_power_object.freqs - else: - escnn_group = template_power_object.group - freq = template_power_object.freqs - - template_power = template_power_object.power - template_power = np.where(template_power < 1e-20, 0, template_power) - flattened_template_power = template_power.flatten() - - power_idx = np.argsort(flattened_template_power)[-5:][::-1] - model_powers_over_time, steps = power.model_power_over_time( - group_name=group_name, - group=escnn_group, - model=model.to(device), - param_history=param_history, - model_inputs=X_tensor, - ) - - fig = plt.figure(figsize=(6, 7)) - - for i in power_idx: - if group_name == "cnxcn": - label = rf"$\xi = ({freq[i][0]:.1f}, {freq[i][1]:.1f})$" - elif group_name == "cn": - label = rf"$\xi = {freq[i]:.1f}$" - else: - label = rf"$\xi = {freq[i]:.1f} (dim={escnn_group.irreps()[i].size})$" - plt.plot(steps, model_powers_over_time[:, i], color=f"C{i}", lw=3, label=label) - plt.axhline( - flattened_template_power[i], - color=f"C{i}", - linestyle="dotted", - linewidth=2, - alpha=0.5, - zorder=-10, - ) - - ymin, ymax = plt.ylim() - if logscale: - plt.yscale("log") - yticks = np.logspace(np.log10(max(ymin, 1e-8)), np.log10(ymax), num=6) - yticklabels = [f"{t:.1e}" for t in yticks] - plt.yticks( - yticks, - yticklabels, - fontsize=FONT_SIZES["ticks"] if "ticks" in FONT_SIZES else 18, - ) - else: - yticks = np.linspace(ymin, ymax, num=6) - # Use scientific notation with one significant digit for yticks - yticklabels = [f"{t:.1e}" for t in yticks] - plt.yticks( - yticks, - yticklabels, - fontsize=FONT_SIZES["ticks"] if "ticks" in FONT_SIZES else 18, - ) - - plt.xscale("log") - plt.xlim(0, len(param_history) - 1) - 
tick_locs = [v for v in [100, 1000, 10000, 100000] if v < len(param_history) - 1] - tick_labels = [rf"$10^{{{int(np.log10(loc))}}}$" for loc in tick_locs] - plt.xticks( - tick_locs, - tick_labels, - fontsize=FONT_SIZES["ticks"] if "ticks" in FONT_SIZES else 18, - ) - - plt.ylabel("Power", fontsize=FONT_SIZES["axes_label"]) - plt.xlabel("Epochs", fontsize=FONT_SIZES["axes_label"]) - plt.legend( - fontsize=FONT_SIZES["legend"], - title="Frequency", - title_fontsize=FONT_SIZES["legend"], - loc="upper left", - labelspacing=0.25, - ) - plt.grid(False) - plt.tight_layout() - - if save_path is not None: - plt.savefig(save_path, bbox_inches="tight") - if show: - plt.show() - - if return_freq_colors: - freq_colors = [f"C{i}" for i in power_idx] - return fig, freq_colors - - return fig - - -def plot_neuron_weights( - config, - model, - neuron_indices=None, - save_path=None, - show=False, -): - """ - Plot the weights of specified neurons in the last linear layer of the model. - 2D visualization (imshow) if group is 'cnxcn', otherwise 1D line plot. - - Parameters - ---------- - config : dict - Configuration dictionary containing 'group_name' and 'group_size'. - model : nn.Module - The trained model. - neuron_indices : list of int or int, optional - Indices of neurons to plot. If None, randomly selects 10 neurons (or all if <=16). - save_path : str, optional - Path to save the plot. If None, the plot is not saved. - show : bool, optional - If True, display the plot window. 
- """ - # Get the last linear layer's weights - last_layer = None - modules = list(model.modules()) - for module in reversed(modules): - if hasattr(module, "weight") and hasattr(module, "bias"): - last_layer = module - weights = last_layer.weight.detach().cpu().numpy() - break - if last_layer is None: - if hasattr(model, "U"): - weights = model.U.detach().cpu().numpy() - elif last_layer is not None: - weights = last_layer.weight.detach().cpu().numpy() - else: - raise ValueError( - "No suitable weights found in model (neither nn.Linear nor custom nn.Parameter 'U')." - ) - - # Select neurons - if neuron_indices is None: - if len(weights) <= 16: - neuron_indices = list(range(len(weights))) - else: - neuron_indices = np.random.choice(range(len(weights)), 10, replace=False) - if isinstance(neuron_indices, int): - neuron_indices = [neuron_indices] - - # Setup subplots - n_plots = len(neuron_indices) - n_cols = min(5, n_plots) - n_rows = (n_plots + 4) // 5 - fig, axs = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows)) - axs = np.array(axs).reshape(-1) - - for i, idx in enumerate(neuron_indices): - w = weights[idx] - if w.shape[0] != config["group_size"]: - raise ValueError( - f"Expected weight size group_size={config['group_size']}, got {weights.shape[0]}" - ) - if config["group_name"] == "cnxcn": # 2D irreps - img_len = int(np.sqrt(config["group_size"])) - w_img = w.reshape(img_len, img_len) - axs[i].imshow(w_img, cmap="viridis") - axs[i].set_title(f"Neuron {idx}") - axs[i].axis("off") - else: # 1D irreps - axs[i].plot(np.arange(config["group_size"]), w, lw=2) - axs[i].set_title(f"Neuron {idx}") - axs[i].set_xlabel("Input Index") - axs[i].set_ylabel("Weight Value") - # Hide unused subplots - for j in range(len(neuron_indices), len(axs)): - axs[j].axis("off") - plt.tight_layout() - if save_path is not None: - plt.savefig(save_path, bbox_inches="tight") - if show: - plt.show() - plt.close(fig) - return fig - - -def plot_model_outputs( - group_name, - 
group_size, - model, - X, - Y, - idx, - step=-1, - param_history=None, - save_path=None, - show=False, -): - """ - Plot a training target vs the model output, adapting plot style for cnxcn (2D) and other groups (1D). - - Parameters - ---------- - group_name : object or str - The group instance or name (should distinguish 'cnxcn'). - group_size : int - The value of group_size. - model : nn.Module - The trained model. - X : torch.Tensor - Input data of shape (num_samples, ...) for the group. - Y : torch.Tensor - Target data of shape (num_samples, output_size). - idx : int or list-like - Index/indices of the data sample(s) to plot target vs output for. - step : int, optional - Snapshot step from param_history to visualize. If -1, uses current model parameters. - param_history : list of dict, optional - List of parameter snapshots (as yielded by model.state_dict()) during training. - If not provided, uses the latest model state. - save_path : str, optional - Path to save the plot. If None, the plot is not saved. - show : bool, optional - If True, display the plot window. - Returns - ------- - fig : matplotlib.figure.Figure - The resulting matplotlib figure handle. 
- """ - with torch.no_grad(): - # Accept single int or list/array of ints for idx - if isinstance(idx, collections.abc.Sequence) and not isinstance(idx, str): - idx_list = list(idx) - else: - idx_list = [idx] - n_samples = len(idx_list) - - # Restore model parameters for a specific step if param_history is provided - if param_history is not None and step is not None and isinstance(step, int): - # Clamp step index - s_idx = step - if s_idx < 0: - s_idx = len(param_history) + s_idx - s_idx = max(0, min(s_idx, len(param_history) - 1)) - model.load_state_dict(param_history[s_idx]) - - # Prepare output for all idx - x = X[idx_list] - y = Y[idx_list] - - if hasattr(model, "eval"): - model.eval() - output = model(x) - - # Convert tensors to numpy arrays - def to_numpy(t): - if torch.is_tensor(t): - return t.detach().cpu().numpy() - return np.array(t) - - x_np = to_numpy(x) - y_np = to_numpy(y) - output_np = to_numpy(output) - - plot_is_2D = group_name == "cnxcn" - - # --- 2D plotting for cnxcn --- - if plot_is_2D: - image_size = int(np.sqrt(group_size)) - input_flat_dim = x_np.shape[-1] - split_point = input_flat_dim // 2 if x_np.ndim == 2 else 0 - - fig, axs = plt.subplots( - n_samples, 4, figsize=(15, 3 * n_samples), sharey=True, squeeze=False - ) - - for row, (x_item, output_item, y_item) in enumerate(zip(x_np, output_np, y_np)): - # Flatten and squeeze to expected shapes - x_item = np.squeeze(x_item) - y_item = np.squeeze(y_item) - output_item = np.squeeze(output_item) - # For input, show two images side by side if x_item has two images - axs[row, 0].imshow( - x_item[:split_point].reshape(image_size, image_size), cmap="viridis" - ) - axs[row, 0].set_title("Input 1") - - axs[row, 1].imshow( - x_item[split_point:].reshape(image_size, image_size), cmap="viridis" - ) - axs[row, 1].set_title("Input 2") - - axs[row, 2].imshow(output_item.reshape(image_size, image_size), cmap="viridis") - axs[row, 2].set_title("Output") - - axs[row, 3].imshow(y_item.reshape(image_size, 
image_size), cmap="viridis") - axs[row, 3].set_title("Target") - for col in range(4): - axs[row, col].axis("off") - - suptitle_str = f"Model Inputs, Outputs, and Targets at index {idx}" - if param_history is not None and step is not None and isinstance(step, int): - suptitle_str += f" (step {s_idx})" - fig.suptitle(suptitle_str, fontsize=FONT_SIZES["title"]) - plt.tight_layout() - - # --- 1D plotting for other groups --- - else: - fig, axs = plt.subplots( - n_samples, 2, figsize=(12, 3 * n_samples), sharey=True, squeeze=False - ) - - for row, (output_item, y_item) in enumerate(zip(output_np, y_np)): - # ensure all items are 1d or flat - y_item = np.squeeze(y_item) - output_item = np.squeeze(output_item) - - axs[row, 0].plot(np.arange(group_size), output_item, lw=2) - axs[row, 0].set_title("Output") - - axs[row, 1].plot(np.arange(group_size), y_item, lw=2) - axs[row, 1].set_title("Target") - for col in range(2): - axs[row, col].set_xlabel("Index") - axs[row, col].set_ylabel("Value") - - suptitle_str = f"Model Outputs and Targets at index {idx}" - if param_history is not None and step is not None and isinstance(step, int): - suptitle_str += f" (step {s_idx})" - fig.suptitle(suptitle_str, fontsize=FONT_SIZES["title"]) - plt.tight_layout() - - if save_path is not None: - plt.savefig(save_path, bbox_inches="tight") - - if show: - plt.show() - plt.close(fig) - - return fig - - -def plot_irreps(group, show=False): - """Plot the irreducible representations (irreps) of the group and their corresponding power in the template. - - Parameters - ---------- - group : class instance - The group for which the irreps are being plotted. Should have a method `get_irreps()` that returns a list of irreps. - show : bool, optional - Whether to display the plot immediately. If False, the plot is not shown but the figure object is returned. 
- """ - irreps = group.irreps() - - group_elements = group.elements - irreps = group.irreps() - - num_irreps = len(irreps) - fig, axs = plt.subplots(1, num_irreps, figsize=(3 * num_irreps, 4), squeeze=False) - axs = axs[0] - - for i, irrep in enumerate(irreps): - # Evaluate irrep on all group elements - matrices = [irrep(g) for g in group_elements] - matrices = np.array(matrices) # (num_elements, d, d) or (num_elements,) for 1D - - if matrices.ndim == 1 or (matrices.ndim == 2 and matrices.shape[1] == 1): - # 1D irrep: plot as real line (vs. group element index) - axs[i].plot(range(len(group_elements)), matrices.real, marker="o", label="Re") - if np.any(np.abs(matrices.imag) > 1e-10): - axs[i].plot(range(len(group_elements)), matrices.imag, marker="x", label="Im") - axs[i].set_title(f"Irrep {i}: {str(irrep)} (dim=1)") - axs[i].set_xlabel("Group element idx") - axs[i].set_ylabel("Irrep value") - axs[i].legend() - else: - d = matrices.shape[1] - num_group_elements = len(group_elements) - num_irrep_entries = d * d - irrep_matrix_entries = matrices.real.reshape(num_group_elements, num_irrep_entries) - im = axs[i].imshow(irrep_matrix_entries, aspect="auto", cmap="viridis") - axs[i].set_title(f"Irrep {i}: {str(irrep)} (size={d}x{d})") - axs[i].set_xlabel("Flattened Irreps") - axs[i].set_ylabel("Irrep(g)") - plt.colorbar(im, ax=axs[i]) - fig.suptitle( - "Irreducible Representations (matrix values for all group elements)", - fontsize=FONT_SIZES["title"], - ) - plt.tight_layout() - if show: - plt.show() - return fig diff --git a/src/power.py b/src/power.py index 9519e6b..5376bbe 100644 --- a/src/power.py +++ b/src/power.py @@ -1,7 +1,7 @@ import numpy as np import torch -import src.group_fourier_transform as gft +import src.fourier as fourier class CyclicPower: @@ -185,12 +185,13 @@ def group_power_spectrum(self): power_spectrum : np.ndarray, shape=[len(group.irreps())] The power spectrum of the template. 
""" + fourier_coefs = fourier.group_fourier(self.group, self.template) irreps = self.group.irreps() power_spectrum = np.zeros(len(irreps)) for i, irrep in enumerate(irreps): - fourier_coef = gft.compute_group_fourier_coef(self.group, self.template, irrep) - power_spectrum[i] = irrep.size * np.trace(fourier_coef.conj().T @ fourier_coef) + fc = fourier_coefs[i] + power_spectrum[i] = irrep.size * np.trace(fc.conj().T @ fc) power_spectrum = power_spectrum / self.group.order() return np.array(power_spectrum) @@ -311,3 +312,220 @@ def model_power_over_time(group_name, model, param_history, model_inputs, group= powers_over_time[powers_over_time < 1e-20] = 0 return powers_over_time, steps + + +# --------------------------------------------------------------------------- +# Power spectrum computation functions (moved from utils.py) +# --------------------------------------------------------------------------- + + +def get_power_1d(points_1d): + """Compute 1D power spectrum using rfft (for real-valued inputs). + + Args: + points_1d: (p,) array + + Returns: + power: (p//2+1,) array of power values + freqs: frequency indices + """ + p = len(points_1d) + + ft = np.fft.rfft(points_1d) + power = np.abs(ft) ** 2 / p + + power = 2 * power.copy() + power[0] = power[0] / 2 # DC component + if p % 2 == 0: + power[-1] = power[-1] / 2 # Nyquist frequency + + freqs = np.fft.rfftfreq(p, 1.0) * p + + return power, freqs + + +def topk_template_freqs_1d(template_1d: np.ndarray, K: int, min_power: float = 1e-20): + """Return top-K frequency indices by power for 1D template. 
+ + Args: + template_1d: 1D template array (p,) + K: Number of top frequencies to return + min_power: Minimum power threshold + + Returns: + List of frequency indices (as integers) + """ + power, _ = get_power_1d(template_1d) + mask = power > min_power + if not np.any(mask): + return [] + valid_power = power[mask] + valid_indices = np.flatnonzero(mask) + top_idx = valid_indices[np.argsort(valid_power)[::-1]][:K] + return top_idx.tolist() + + +def topk_template_freqs(template_2d: np.ndarray, K: int, min_power: float = 1e-20): + """Return top-K (kx, ky) rFFT2 bins by power from get_power_2d(template_2d).""" + freqs_u, freqs_v, power = get_power_2d(template_2d) + shp = power.shape + flat = power.ravel() + mask = flat > min_power + if not np.any(mask): + return [] + top_idx = np.flatnonzero(mask)[np.argsort(flat[mask])[::-1]][:K] + kx, ky = np.unravel_index(top_idx, shp) + return list(zip(kx.tolist(), ky.tolist())) + + +def get_power_2d(points, no_freq=False): + """Compute 2D power spectrum using rfft2 with proper symmetry handling. 
+ + Args: + points: (M, N) array, the 2D signal + no_freq: if True, only return power (no frequency arrays) + + Returns: + freqs_u: frequency bins for rows (if no_freq=False) + freqs_v: frequency bins for columns (if no_freq=False) + power: 2D power spectrum (M, N//2 + 1) + """ + M, N = points.shape + + ft = np.fft.rfft2(points) + power = np.abs(ft) ** 2 / (M * N) + + weight = 2 * np.ones((M, N // 2 + 1)) + weight[0, 0] = 1 + weight[(M // 2 + 1) :, 0] = 0 + if M % 2 == 0: + weight[M // 2, 0] = 1 + if N % 2 == 0: + weight[(M // 2 + 1) :, N // 2] = 0 + weight[0, N // 2] = 1 + if (M % 2 == 0) and (N % 2 == 0): + weight[M // 2, N // 2] = 1 + + power = weight * power + + total_power = np.sum(power) + norm_squared = np.linalg.norm(points) ** 2 + if not np.isclose(total_power, norm_squared, rtol=1e-6): + print( + f"Warning: Total power {total_power:.3f} does not match norm squared {norm_squared:.3f}" + ) + + if no_freq: + return power + + freqs_u = np.fft.fftfreq(M) + freqs_v = np.fft.rfftfreq(N) + + return freqs_u, freqs_v, power + + +def _tracked_power_from_fft2(power2d, kx, ky, p1, p2): + """Sum power at (kx, ky) and its real-signal mirror (-kx, -ky). + + Args: + power2d: 2D power spectrum from fft2 (shape: p1, p2) + kx, ky: Frequency indices + p1, p2: Dimensions of the signal + + Returns: + float: Total power at this frequency (including mirror) + """ + i0, j0 = kx % p1, ky % p2 + i1, j1 = (-kx) % p1, (-ky) % p2 + if (i0, j0) == (i1, j1): + return float(power2d[i0, j0]) + return float(power2d[i0, j0] + power2d[i1, j1]) + + +def theoretical_loss_levels_2d(template_2d): + """Compute theoretical MSE loss levels based on 2D template power spectrum. 
+ + Args: + template_2d: 2D template array (p1, p2) + + Returns: + dict with 'initial', 'final', and 'levels' keys + """ + p1, p2 = template_2d.shape + power = get_power_2d(template_2d, no_freq=True) + + power_flat = power.flatten() + power_flat = np.sort(power_flat[power_flat > 1e-20])[::-1] + + coef = 1.0 / (p1 * p2) + levels = [coef * np.sum(power_flat[k:]) for k in range(len(power_flat) + 1)] + + return { + "initial": levels[0] if levels else 0.0, + "final": 0.0, + "levels": levels, + } + + +def theoretical_loss_levels_1d(template_1d): + """Compute theoretical MSE loss levels based on 1D template power spectrum. + + Args: + template_1d: 1D template array (p,) + + Returns: + dict with 'initial', 'final', and 'levels' keys + """ + p = len(template_1d) + power, _ = get_power_1d(template_1d) + + power = np.sort(power[power > 1e-20])[::-1] + + coef = 1.0 / p + levels = [coef * np.sum(power[k:]) for k in range(len(power) + 1)] + + return { + "initial": levels[0] if levels else 0.0, + "final": 0.0, + "levels": levels, + } + + +# Backward compatibility aliases +def theoretical_final_loss_2d(template_2d): + """Returns expected initial loss (for setting convergence targets).""" + return theoretical_loss_levels_2d(template_2d)["initial"] + + +def theoretical_final_loss_1d(template_1d): + """Returns expected initial loss (for setting convergence targets).""" + return theoretical_loss_levels_1d(template_1d)["initial"] + + +def group_power_spectrum(group, template): + """Compute the (group) power spectrum of the template. + + For each irrep rho, the power is given by: + ||hat x(rho)||_rho = dim(rho) * Tr(hat x(rho)^dagger * hat x(rho)) + + Parameters + ---------- + group : Group (escnn object) + The group. + template : np.ndarray, shape=[group.order()] + The template to compute the power spectrum of. + + Returns + ------- + power_spectrum : np.ndarray, shape=[len(group.irreps())] + The power spectrum of the template. 
+ """ + fourier_coefs = fourier.group_fourier(group, template) + irreps = group.irreps() + + power_spectrum = np.zeros(len(irreps)) + for i, irrep in enumerate(irreps): + fc = fourier_coefs[i] + power_spectrum[i] = irrep.size * np.trace(fc.conj().T @ fc) + power_spectrum = power_spectrum / group.order() + return np.array(power_spectrum) diff --git a/src/run_sweep.py b/src/run_sweep.py index 798371c..18497bf 100644 --- a/src/run_sweep.py +++ b/src/run_sweep.py @@ -270,10 +270,10 @@ def run_single_seed( try: # Import here to avoid circular dependency - from src.main import train_single_run + import src.main as main # Run training - result = train_single_run(seed_config, run_dir=seed_dir) + result = main.train_single_run(seed_config, run_dir=seed_dir) # Save run summary run_summary = { diff --git a/src/template.py b/src/template.py new file mode 100644 index 0000000..12589ea --- /dev/null +++ b/src/template.py @@ -0,0 +1,667 @@ +import numpy as np +from skimage.transform import resize +from sklearn.datasets import fetch_openml +from sklearn.utils import shuffle + +import src.fourier as fourier + +# --------------------------------------------------------------------------- +# Templates from the original templates.py (names kept as-is per plan) +# --------------------------------------------------------------------------- + + +def one_hot(p): + """One-hot encode an integer value in R^p.""" + vec = np.zeros(p) + vec[1] = 10 + + zeroth_freq = np.mean(vec) + vec = vec - zeroth_freq + return vec + + +def fixed_cn(group_size, fourier_coef_mags): + """Generate a fixed template for the 1D modular addition dataset. + + Parameters + ---------- + group_size : int + n in Cn. Number of elements in the 1D modular addition + fourier_coef_mags : list of float + Magnitudes of the Fourier coefficients to set. + + Returns + ------- + template : np.ndarray + A 1D array of shape (group_size,) representing the template. 
+ """ + spectrum = np.zeros(group_size, dtype=complex) + + spectrum[0] = fourier_coef_mags[0] + fourier_coef_mags = fourier_coef_mags[1:] + + for i_mag, mag in enumerate(fourier_coef_mags): + mode = i_mag + 1 + spectrum[mode] = mag + spectrum[-mode] = np.conj(mag) + print("Setting mode:", mode, "with magnitude:", mag) + + template = np.fft.ifft(spectrum).real + + zeroth_freq = np.mean(template) + template = template - zeroth_freq + + return template + + +def fixed_cnxcn(image_length, fourier_coef_mags): + """Generate a fixed template for the 2D modular addition dataset. + + Parameters + ---------- + image_length : int + image_length = n in Cn x Cn. + fourier_coef_mags : list of float + Magnitudes of the Fourier coefficients to set. + + Returns + ------- + template : np.ndarray + A flattened 2D array of shape (image_length*image_length,). + """ + spectrum = np.zeros((image_length, image_length), dtype=complex) + + spectrum[0, 0] = fourier_coef_mags[0] + fourier_coef_mags = fourier_coef_mags[1:] + + def mode_selector(i_mag): + i_mode = 1 + i_mag // 3 + mode_type = i_mag % 3 + if mode_type == 0: + return (i_mode, 0) + elif mode_type == 1: + return (0, i_mode) + else: + return (i_mode, i_mode) + + i_mag = 0 + while i_mag < len(fourier_coef_mags): + mode = mode_selector(i_mag) + + spectrum[mode[0], mode[1]] = fourier_coef_mags[i_mag] + spectrum[-mode[0], -mode[1]] = np.conj(fourier_coef_mags[i_mag]) + print("Setting mode:", mode, "with magnitude:", fourier_coef_mags[i_mag]) + i_mag += 1 + + template = np.fft.ifft2(spectrum).real + + template = template.flatten() + + zeroth_freq = np.mean(template) + template = template - zeroth_freq + + return template + + +def fixed_group(group, fourier_coef_diag_values): + """Generate a fixed template for a group with non-zero Fourier coefficients for specific irreps. + + Parameters + ---------- + group : Group (escnn object) + The group. 
+ fourier_coef_diag_values : list of float + Diagonal values for each irrep's Fourier coefficient matrix. + + Returns + ------- + template : np.ndarray, shape=[group.order()] + The mean centered template. + """ + spectrum = [] + assert len(fourier_coef_diag_values) == len(group.irreps()), ( + f"Number of Fourier coef. magnitudes on the diagonal {len(fourier_coef_diag_values)} must match number of irreps {len(group.irreps())}" + ) + for i, irrep in enumerate(group.irreps()): + diag_values = np.full(irrep.size, fourier_coef_diag_values[i], dtype=float) + mat = np.zeros((irrep.size, irrep.size), dtype=float) + np.fill_diagonal(mat, diag_values) + print(f"mat for irrep {i} of dimension {irrep.size} is:\n {mat}\n") + + spectrum.append(mat) + + template = fourier.group_fourier_inverse(group, spectrum) + + zeroth_freq = np.mean(template) + template = template - zeroth_freq + + return template + + +def mnist(image_length, digit=0, sample_idx=0, random_state=42): + """Generate a template from the MNIST dataset, resized to p x p. + + Parameters + ---------- + image_length : int + p in Z/pZ x Z/pZ. + digit : int, optional + The MNIST digit to use as a template (0-9). + sample_idx : int, optional + The index of the sample to use. + random_state : int, optional + Random seed for shuffling. + + Returns + ------- + template : np.ndarray + A flattened 2D array of shape (image_length*image_length,). + """ + mnist = fetch_openml("mnist_784", version=1) + X = mnist.data.values + y = mnist.target.astype(int).values + + X_digit = X[y == digit] + + if X_digit.shape[0] == 0: + raise ValueError(f"No samples found for digit {digit} in MNIST dataset.") + + X_digit = shuffle(X_digit, random_state=random_state) + if sample_idx >= X_digit.shape[0]: + raise IndexError( + f"sample_idx {sample_idx} is out of bounds for digit {digit} (found {X_digit.shape[0]} samples)." 
+ ) + sample = X_digit[sample_idx].reshape(28, 28) + + sample_resized = resize(sample, (image_length, image_length), anti_aliasing=True) + + sample_resized = (sample_resized - np.min(sample_resized)) / ( + np.max(sample_resized) - np.min(sample_resized) + ) + + template = sample_resized.flatten() + + zeroth_freq = np.mean(template) + template = template - zeroth_freq + + return template + + +def template_selector(config): + """Select template based on configuration.""" + if config["template_type"] == "irrep_construction": + if config["group_name"] == "cnxcn": + template = fixed_cnxcn(config["image_length"], config["fourier_coef_diag_values"]) + elif config["group_name"] == "cn": + template = fixed_cn(config["group_n"], config["fourier_coef_diag_values"]) + else: + template = fixed_group(config["group"], config["fourier_coef_diag_values"]) + elif config["template_type"] == "one_hot": + template = one_hot(config["group_size"]) + else: + raise ValueError(f"Unknown template type: {config['template_type']}") + return template + + +# --------------------------------------------------------------------------- +# Template functions moved from datamodule.py (renamed per plan) +# --------------------------------------------------------------------------- + + +def mnist_1d(p: int, label: int, root: str = "data", axis: int = 0): + """Return a (p,) 1D template from a random MNIST image by taking a slice or projection. 
+ + Args: + p: dimension of the cyclic group + label: MNIST digit class (0-9) + root: MNIST data directory + axis: 0 for row average, 1 for column average, 2 for diagonal + + Returns: + template: (p,) array + """ + import torch + import torchvision + import torchvision.transforms as transforms + + if not (0 <= int(label) <= 9): + raise ValueError("label must be an integer in [0, 9].") + + ds = torchvision.datasets.MNIST( + root=root, train=True, download=True, transform=transforms.ToTensor() + ) + cls_idxs = (ds.targets == int(label)).nonzero(as_tuple=True)[0] + if cls_idxs.numel() == 0: + raise ValueError(f"No samples for label {label}.") + + idx = cls_idxs[torch.randint(len(cls_idxs), (1,)).item()].item() + img, _ = ds[idx] + img = img[0].numpy() + + if axis == 0: + signal = img.mean(axis=1) + elif axis == 1: + signal = img.mean(axis=0) + elif axis == 2: + signal = np.diag(img) + else: + raise ValueError("axis must be 0, 1, or 2") + + from scipy.interpolate import interp1d + + x_old = np.linspace(0, 1, len(signal)) + x_new = np.linspace(0, 1, p) + f = interp1d(x_old, signal, kind="cubic") + template = f(x_new) + + return template.astype(np.float32) + + +def mnist_2d(p1: int, p2: int, label: int, root: str = "data"): + """Return a (p1, p2) template from a random MNIST image. 
+ + Args: + p1, p2: dimensions + label: MNIST digit class (0-9) + root: MNIST data directory + + Returns: + template: (p1, p2) array + """ + import torch + import torch.nn as nn + import torchvision + import torchvision.transforms as transforms + + if not (0 <= int(label) <= 9): + raise ValueError("label must be an integer in [0, 9].") + + ds = torchvision.datasets.MNIST( + root=root, train=True, download=True, transform=transforms.ToTensor() + ) + cls_idxs = (ds.targets == int(label)).nonzero(as_tuple=True)[0] + if cls_idxs.numel() == 0: + raise ValueError(f"No samples for label {label}.") + + idx = cls_idxs[torch.randint(len(cls_idxs), (1,)).item()].item() + img, _ = ds[idx] + img = nn.functional.interpolate( + img.unsqueeze(0), size=(p1, p2), mode="bilinear", align_corners=False + )[0, 0] + return img.numpy().astype(np.float32) + + +# --- 1D Synthetic Templates --- + + +def fourier_1d(p: int, n_freqs: int, amp_max: float = 100, amp_min: float = 10, seed=None): + """Generate 1D template from random Fourier modes. 
+ + Args: + p: dimension of cyclic group + n_freqs: number of frequency components to include + amp_max: maximum amplitude + amp_min: minimum amplitude + seed: random seed + + Returns: + template: (p,) real-valued array + """ + rng = np.random.default_rng(seed) + spectrum = np.zeros(p, dtype=np.complex128) + + available_freqs = list(range(1, p // 2 + 1)) + if len(available_freqs) < n_freqs: + raise ValueError( + f"Only {len(available_freqs)} non-DC frequencies available for p={p}, requested {n_freqs}" + ) + + chosen_freqs = rng.choice( + available_freqs, size=min(n_freqs, len(available_freqs)), replace=False + ) + + amps = np.sqrt(np.linspace(amp_max, amp_min, len(chosen_freqs))) + phases = rng.uniform(0.0, 2 * np.pi, size=len(chosen_freqs)) + + for freq, amp, phi in zip(chosen_freqs, amps, phases): + v = amp * np.exp(1j * phi) + spectrum[freq] = v + spectrum[-freq] = np.conj(v) + + template = np.fft.ifft(spectrum).real + template -= template.mean() + s = template.std() + if s > 1e-12: + template /= s + + return template.astype(np.float32) + + +def gaussian_1d(p: int, n_gaussians: int = 3, sigma_range: tuple = (0.5, 2.0), seed=None): + """Generate 1D template as sum of Gaussians. + + Args: + p: dimension of cyclic group + n_gaussians: number of Gaussian bumps + sigma_range: (min_sigma, max_sigma) for Gaussian widths + seed: random seed + + Returns: + template: (p,) real-valued array + """ + rng = np.random.default_rng(seed) + x = np.arange(p) + template = np.zeros(p, dtype=np.float32) + + for _ in range(n_gaussians): + center = rng.uniform(0, p) + sigma = rng.uniform(*sigma_range) + amplitude = rng.uniform(0.5, 1.0) + + dist = np.minimum(np.abs(x - center), p - np.abs(x - center)) + template += amplitude * np.exp(-(dist**2) / (2 * sigma**2)) + + template -= template.mean() + s = template.std() + if s > 1e-12: + template /= s + + return template.astype(np.float32) + + +def onehot_1d(p: int): + """Generate 1D one-hot template for cyclic group C_p. 
+ + Args: + p: dimension of cyclic group + + Returns: + template: (p,) array with template[0] = 1, all others = 0 + """ + template = np.zeros(p, dtype=np.float32) + template[0] = 1.0 + return template + + +# --- 2D Synthetic Templates --- + + +def gaussian_mixture_2d( + p1=20, + p2=20, + n_blobs=8, + frac_broad=0.7, + sigma_broad=(3.5, 6.0), + sigma_narrow=(1.0, 2.0), + amp_broad=1.0, + amp_narrow=0.5, + seed=None, + normalize=True, +): + """Build a (p1 x p2) template as a periodic mixture of Gaussians.""" + rng = np.random.default_rng(seed) + H, W = p1, p2 + Y, X = np.meshgrid(np.arange(H), np.arange(W), indexing="ij") + + k_broad = int(round(n_blobs * frac_broad)) + k_narrow = n_blobs - k_broad + + def add_blobs(k, sigma_range, amp): + out = np.zeros((H, W), dtype=float) + for _ in range(k): + cy, cx = rng.uniform(0, H), rng.uniform(0, W) + sigma = rng.uniform(*sigma_range) + dy = np.minimum(np.abs(Y - cy), H - np.abs(Y - cy)) + dx = np.minimum(np.abs(X - cx), W - np.abs(X - cx)) + out += amp * np.exp(-(dx**2 + dy**2) / (2.0 * sigma**2)) + return out + + template = add_blobs(k_broad, sigma_broad, amp_broad) + add_blobs( + k_narrow, sigma_narrow, amp_narrow + ) + + if normalize: + template -= template.mean() + s = template.std() + if s > 1e-12: + template /= s + return template.astype(np.float32) + + +def unique_freqs_2d(p1, p2, n_freqs, amp_max=100, amp_min=10, seed=None): + """Real (p1 x p2) template from n_freqs unique Fourier modes. + + Each chosen frequency bin has no conjugate collision. 
+ + Args: + p1, p2: spatial dims + n_freqs: number of frequency components + amp_max, amp_min: amplitude range + seed: random seed + + Returns: + template: (p1, p2) real-valued array + """ + rng = np.random.default_rng(seed) + spectrum = np.zeros((p1, p2), dtype=np.complex128) + + def ky_signed(ky): + return ky if ky <= p1 // 2 else ky - p1 + + def is_self_conj(ky, kx): + on_self_kx = (kx == 0) or (p2 % 2 == 0 and kx == p2 // 2) + if not on_self_kx: + return False + s = ky_signed(ky) + return (s == 0) or (p1 % 2 == 0 and abs(s) == p1 // 2) + + cand = [] + for ky in range(p1): + s = ky_signed(ky) + for kx in range(p2 // 2 + 1): + if ky == 0 and kx == 0: + continue + if is_self_conj(ky, kx): + continue + r2 = (s**2) + (kx**2) + cand.append((r2, ky, kx)) + cand.sort(key=lambda t: (t[0], abs(ky_signed(t[1])), t[2])) + + chosen = [] + seen_axis_pairs = set() + + mid_kx = p2 // 2 if (p2 % 2 == 0) else None + for _, ky, kx in cand: + if len(chosen) >= n_freqs: + break + if (kx == 0) or (mid_kx is not None and kx == mid_kx): + s = ky_signed(ky) + key = (kx, min(s, -s)) + if key in seen_axis_pairs: + continue + seen_axis_pairs.add(key) + chosen.append((ky, kx)) + else: + chosen.append((ky, kx)) + + if len(chosen) < n_freqs: + raise ValueError( + f"Could only find {len(chosen)} unique non-conjugate bins; " + f"requested {n_freqs}. Increase grid size or reduce n_freqs." + ) + + amps = np.sqrt(np.linspace(amp_max, amp_min, n_freqs, dtype=float)) + phases = rng.uniform(0.0, 2 * np.pi, size=n_freqs) + + for (ky, kx), a, phi in zip(chosen, amps, phases): + kyc, kxc = (-ky) % p1, (-kx) % p2 + v = a * np.exp(1j * phi) + spectrum[ky, kx] += v + spectrum[kyc, kxc] += np.conj(v) + + template = np.fft.ifft2(spectrum).real + template -= template.mean() + s = template.std() + if s > 1e-12: + template /= s + return template.astype(np.float32) + + +def fixed_2d(p1: int, p2: int) -> np.ndarray: + """Generate 2D template array from Fourier spectrum. 
+ + Args: + p1: height dimension + p2: width dimension + + Returns: + template: (p1, p2) real-valued array + """ + spectrum = np.zeros((p1, p2), dtype=complex) + + assert p1 > 5 and p2 > 5, "p1 and p2 must be greater than 5" + + spectrum[1, 0] = 10.0 + spectrum[-1, 0] = 10.0 + + spectrum[0, 3] = 7.5 + spectrum[0, -3] = 7.5 + + spectrum[2, 1] = 5.0 + spectrum[-2, -1] = 5.0 + + template = np.fft.ifft2(spectrum).real + + return template + + +def _fft_indices(n): + """Return integer-like frequency indices aligned with numpy's FFT layout.""" + k = np.fft.fftfreq(n) * n + return k.astype(int) + + +def hexagon_tie_2d(p1: int, p2: int, k0: float = 6.0, amp: float = 1.0): + """Real template with hexagonal Fourier spectrum. + + Args: + p1, p2: spatial dims + k0: desired radius (index units) + amp: amplitude per spike + + Returns: + template: (p1, p2) real-valued array + """ + assert p1 > 5 and p2 > 5, "p1 and p2 must be > 5" + spec = np.zeros((p1, p2), dtype=np.complex128) + + thetas = np.arange(6) * (np.pi / 3.0) + + Kx = _fft_indices(p1) + Ky = _fft_indices(p2) + + def put(kx, ky, val): + spec[int(kx) % p1, int(ky) % p2] += val + + used = set() + for th in thetas: + kx_f = k0 * np.cos(th) + ky_f = k0 * np.sin(th) + kx = int(np.round(kx_f)) + ky = int(np.round(ky_f)) + if (kx, ky) == (0, 0): + if abs(np.cos(th)) > abs(np.sin(th)): + kx = 1 if kx >= 0 else -1 + else: + ky = 1 if ky >= 0 else -1 + if (kx, ky) in used: + continue + used.add((kx, ky)) + used.add((-kx, -ky)) + + put(kx, ky, amp) + put(-kx, -ky, np.conjugate(amp)) + + spec[0, 0] = 0.0 + + x = np.fft.ifft2(spec).real + return x + + +def ring_isotropic_2d( + p1: int, p2: int, r0: float = 6.0, sigma: float = 0.5, total_power: float = 1.0 +): + """Real template with an isotropic ring in the 2D spectrum. 
+ + Args: + p1, p2: spatial dims + r0: target radius (index units) + sigma: radial width of the ring + total_power: scales overall energy + + Returns: + template: (p1, p2) real-valued array + """ + assert p1 > 5 and p2 > 5, "p1 and p2 must be > 5" + + kx = _fft_indices(p1)[:, None] + ky = _fft_indices(p2)[None, :] + R = np.sqrt(kx**2 + ky**2) + + mag = np.exp(-0.5 * ((R - r0) / max(sigma, 1e-6)) ** 2) + + mag[0, 0] = 0.0 + + power = np.sum(mag**2) + if power > 0: + mag *= np.sqrt(total_power / power) + + spec = mag.astype(np.complex128) + + x = np.fft.ifft2(spec).real + return x + + +def gaussian_2d( + p1: int, + p2: int, + center: tuple = None, + sigma: float = 2.0, + k_freqs: int = None, +) -> np.ndarray: + """Generate 2D template with a single Gaussian, optionally band-limited. + + Args: + p1: height dimension + p2: width dimension + center: (cx, cy) center position + sigma: standard deviation of Gaussian + k_freqs: if not None, keep only the top k frequencies by power + + Returns: + template: (p1, p2) real-valued array + """ + if center is None: + center = (p1 / 2, p2 / 2) + cx, cy = center + x = np.arange(p1) + y = np.arange(p2) + X, Y = np.meshgrid(x, y, indexing="ij") + template = np.exp(-((X - cx) ** 2 + (Y - cy) ** 2) / (2 * sigma**2)) + if k_freqs is not None: + spectrum = np.fft.fft2(template) + power = np.abs(spectrum) ** 2 + power_flat = power.flatten() + kx_indices = np.arange(p1) + ky_indices = np.arange(p2) + KX, KY = np.meshgrid(kx_indices, ky_indices, indexing="ij") + all_indices = list(zip(KX.flatten(), KY.flatten())) + sorted_idx = np.argsort(-power_flat) + top_k_idx = sorted_idx[:k_freqs] + top_k_freqs = set([all_indices[i] for i in top_k_idx]) + mask = np.zeros((p1, p2), dtype=complex) + for kx, ky in top_k_freqs: + mask[kx, ky] = 1.0 + spectrum_masked = spectrum * mask + template = np.fft.ifft2(spectrum_masked).real + return template diff --git a/src/templates.py b/src/templates.py deleted file mode 100644 index f2d8de2..0000000 --- 
a/src/templates.py +++ /dev/null @@ -1,206 +0,0 @@ -import numpy as np -from skimage.transform import resize -from sklearn.datasets import fetch_openml -from sklearn.utils import shuffle - -from src.group_fourier_transform import ( - compute_group_inverse_fourier_transform, -) - - -def one_hot(p): - """One-hot encode an integer value in R^p.""" - vec = np.zeros(p) - vec[1] = 10 - - zeroth_freq = np.mean(vec) - vec = vec - zeroth_freq - return vec - - -def fixed_cn_template(group_size, fourier_coef_mags): - """Generate a fixed template for the 1D modular addition dataset. - - Parameters - ---------- - group_size : int - n in Cn. Number of elements in the 1D modular addition - fourier_coef_mags : list of float - Magnitudes of the Fourier coefficients to set. This list can have any length, and the - coefficients will be assigned to frequency modes in increasing order: - 0, 1, 2, ..., n-1 (and then their negative counterparts to ensure a real-valued template) - where 0 represents the zeroth frequency mode. - Returns - ------- - template : np.ndarray - A 1D array of shape (group_size,) representing the template. - """ - # Generate template array from Fourier spectrum - spectrum = np.zeros(group_size, dtype=complex) - - spectrum[0] = fourier_coef_mags[0] # Zeroth frequency component - fourier_coef_mags = fourier_coef_mags[1:] # Exclude zeroth frequency - - for i_mag, mag in enumerate(fourier_coef_mags): - mode = i_mag + 1 # Frequency mode starts from 1 - spectrum[mode] = mag - spectrum[-mode] = np.conj(mag) - print("Setting mode:", mode, "with magnitude:", mag) - - # Generate signal from spectrum - template = np.fft.ifft(spectrum).real - - zeroth_freq = np.mean(template) - template = template - zeroth_freq - - return template - - -def fixed_cnxcn_template(image_length, fourier_coef_mags): - """Generate a fixed template for the 2D modular addition dataset. 
- - Note: Since our input is a flattened matrix, we should un-flatten - the weights vectors to match the shape of the template when we visualize. - - Parameters - ---------- - image_length : int - image_length = n in Cn x Cn. Number of elements per dimension in the 2D modular addition - fourier_coef_mags : list of float - Magnitudes of the Fourier coefficients to set. This list can have any length, and the - coefficients will be assigned to frequency modes in the following order: - (0,0), (1,0), (0,1), (1,1), (2,0), (0,2), (2,2), (3,0), (0,3), (3,3), ... - (and then their negative counterparts to ensure a real-valued template) - where (i,j) represents the frequency mode with frequency i in the first dimension - - Returns - ------- - template : np.ndarray - A flattened 2D array of shape (image_length, image_length) representing the template. - After flattening, it will have shape (image_length*image_length,). - """ - # Generate template array from Fourier spectrum - spectrum = np.zeros((image_length, image_length), dtype=complex) - - spectrum[0, 0] = fourier_coef_mags[0] # Zeroth frequency component - fourier_coef_mags = fourier_coef_mags[1:] # Exclude zeroth frequency - - def mode_selector(i_mag): - i_mode = 1 + i_mag // 3 - mode_type = i_mag % 3 - if mode_type == 0: - return (i_mode, 0) - elif mode_type == 1: - return (0, i_mode) - else: - return (i_mode, i_mode) - - i_mag = 0 - while i_mag < len(fourier_coef_mags): - mode = mode_selector(i_mag) - - spectrum[mode[0], mode[1]] = fourier_coef_mags[i_mag] - spectrum[-mode[0], -mode[1]] = np.conj(fourier_coef_mags[i_mag]) - print("Setting mode:", mode, "with magnitude:", fourier_coef_mags[i_mag]) - i_mag += 1 - - # Generate signal from spectrum - template = np.fft.ifft2(spectrum).real - - template = template.flatten() - - zeroth_freq = np.mean(template) - template = template - zeroth_freq - - return template - - -def fixed_group_template(group, fourier_coef_diag_values): - """Generate a fixed template for a group, 
that has non-zero Fourier coefficients - only for a few irreps. - - Parameters - ---------- - group : Group (escnn object) - The group. - num_irreps : int - Number of irreps to set non-zero Fourier coefficients for. (Default is 3.) - - Returns - ------- - template : np.ndarray, shape=[group.order()] - The mean centered template. - """ - spectrum = [] - assert len(fourier_coef_diag_values) == len(group.irreps()), ( - f"Number of Fourier coef. magnitudes on the diagonal {len(fourier_coef_diag_values)} must match number of irreps {len(group.irreps())}" - ) - for i, irrep in enumerate(group.irreps()): - diag_values = np.full(irrep.size, fourier_coef_diag_values[i], dtype=float) - mat = np.zeros((irrep.size, irrep.size), dtype=float) - np.fill_diagonal(mat, diag_values) - print(f"mat for irrep {i} of dimension {irrep.size} is:\n {mat}\n") - - spectrum.append(mat) - - # Generate signal from spectrum - template = compute_group_inverse_fourier_transform(group, spectrum) - - zeroth_freq = np.mean(template) - template = template - zeroth_freq - - return template - - -def mnist_template(image_length, digit=0, sample_idx=0, random_state=42): - """Generate a template from the MNIST dataset, resized to p x p, for a specified digit. - - Parameters - ---------- - image_length : int - p in Z/pZ x Z/pZ. Number of elements per dimension in the 2D modular addition - digit : int, optional - The MNIST digit to use as a template (0-9). Default is 0. - sample_idx : int, optional - The index of the sample to use among the filtered digit images. Default is 0. - random_state : int, optional - Random seed for shuffling the digit images. Default is 42. - - Returns - ------- - template : np.ndarray - A flattened 2D array of shape (image_length, image_length) representing the template. 
- """ - # Load MNIST dataset - mnist = fetch_openml("mnist_784", version=1) - X = mnist.data.values - y = mnist.target.astype(int).values - - # Filter for the specified digit - X_digit = X[y == digit] - - if X_digit.shape[0] == 0: - raise ValueError(f"No samples found for digit {digit} in MNIST dataset.") - - # Shuffle and select the desired sample - X_digit = shuffle(X_digit, random_state=random_state) - if sample_idx >= X_digit.shape[0]: - raise IndexError( - f"sample_idx {sample_idx} is out of bounds for digit {digit} (found {X_digit.shape[0]} samples)." - ) - sample = X_digit[sample_idx].reshape(28, 28) - - # Resize to p x p - sample_resized = resize(sample, (image_length, image_length), anti_aliasing=True) - - # Normalize to [0, 1] - sample_resized = (sample_resized - np.min(sample_resized)) / ( - np.max(sample_resized) - np.min(sample_resized) - ) - - template = sample_resized.flatten() - - zeroth_freq = np.mean(template) - template = template - zeroth_freq - - return template diff --git a/src/utils.py b/src/utils.py deleted file mode 100644 index f287616..0000000 --- a/src/utils.py +++ /dev/null @@ -1,1737 +0,0 @@ -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.ticker import MaxNLocator - -### ----- VISUALIZATION FUNCTIONS ----- ### - - -def style_axes(ax, numyticks=5, numxticks=5, labelsize=24): - # Y-axis ticks - ax.tick_params( - axis="y", - which="both", - bottom=True, - top=False, - labelbottom=True, - left=True, - right=False, - labelleft=True, - direction="out", - length=7, - width=1.5, - pad=8, - labelsize=labelsize, - ) - ax.yaxis.set_major_locator(MaxNLocator(nbins=numyticks)) - - # X-axis ticks - ax.tick_params( - axis="x", - which="both", - bottom=True, - top=False, - labelbottom=True, - left=True, - right=False, - labelleft=True, - direction="out", - length=7, - width=1.5, - pad=8, - labelsize=labelsize, - ) - ax.xaxis.set_major_locator(MaxNLocator(nbins=numxticks)) - - ax.xaxis.offsetText.set_fontsize(20) - ax.grid() - - # 
Customize spines - for spine in ["top", "right"]: - ax.spines[spine].set_visible(False) - for spine in ["left", "bottom"]: - ax.spines[spine].set_linewidth(3) - - -def plot_train_val_loss( - train_loss_history, val_loss_history, save_path=None, show=True, xlabel="Step" -): - """ - Plot training and validation loss vs steps. - - Args: - train_loss_history: List of training loss values - val_loss_history: List of validation loss values - save_path: Optional path to save figure - show: Whether to display the plot - xlabel: Label for x-axis (e.g., 'Step' or 'Epoch') - """ - fig, ax = plt.subplots(1, 1, figsize=(10, 6)) - - steps = np.arange(len(train_loss_history)) - - ax.plot(steps, train_loss_history, lw=2, color="#1f77b4", label="Training Loss", alpha=0.7) - ax.plot(steps, val_loss_history, lw=2, color="#ff7f0e", label="Validation Loss") - - ax.set_xlabel(xlabel, fontsize=14) - ax.set_ylabel("Loss", fontsize=14) - ax.set_title("Training vs Validation Loss", fontsize=16) - ax.legend(fontsize=12) - ax.grid(True, alpha=0.3) - ax.set_yscale("log") # Log scale often helps see loss curves - - if save_path: - fig.savefig(save_path, bbox_inches="tight", dpi=150) - print(f" ✓ Saved to {save_path}") - - if show: - plt.show() - else: - plt.close(fig) - - return fig, ax - - -def plot_2d_signal( - signal_2d, - title="", - cmap="RdBu_r", - colorbar=True, -): - """Plot a 2D signal as a heatmap.""" - fig, ax = plt.subplots(1, 1, figsize=(10, 6)) - im = ax.imshow(signal_2d, cmap=cmap, aspect="equal", interpolation="nearest") - ax.set_title(title, fontsize=14) - ax.set_xlabel("y", fontsize=12) - ax.set_ylabel("x", fontsize=12) - if colorbar: - plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) - plt.tight_layout() - return fig, ax - - -def plot_2d_power_spectrum( - ax, - power_2d, - fx=None, - fy=None, - title="Power Spectrum", - cmap="viridis", - log_scale=True, - shift=True, -): - """Plot 2D power spectrum with proper frequency axes.""" - if log_scale: - power_plot = 
np.log10(power_2d + 1e-12) - title = f"{title} (log₁₀)" - else: - power_plot = power_2d - - # Optionally shift to center zero frequency - if shift: - power_plot = np.fft.fftshift(power_plot) - if fx is not None and fy is not None: - fx = np.fft.fftshift(fx) - fy = np.fft.fftshift(fy) - - # Set up extent for proper frequency axis labeling - if fx is not None and fy is not None: - extent = [fy.min(), fy.max(), fx.min(), fx.max()] - else: - extent = None - - im = ax.imshow( - power_plot, - cmap=cmap, - aspect="equal", - interpolation="nearest", - origin="lower", - extent=extent, - ) - ax.set_title(title, fontsize=14) - ax.set_xlabel("k_y (frequency)", fontsize=12) - ax.set_ylabel("k_x (frequency)", fontsize=12) - plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) - return im - - -### ----- POWER SPECTRUM FUNCTIONS ----- ### - - -def get_power_1d(points_1d): - """ - Compute 1D power spectrum using rfft (for real-valued inputs). - - Args: - points_1d: (p,) array - - Returns: - power: (p//2+1,) array of power values - freqs: frequency indices - """ - p = len(points_1d) - - # Perform 1D FFT - ft = np.fft.rfft(points_1d) - power = np.abs(ft) ** 2 / p - - # Handle conjugate symmetry for real signals - power = 2 * power.copy() - power[0] = power[0] / 2 # DC component - if p % 2 == 0: - power[-1] = power[-1] / 2 # Nyquist frequency - - freqs = np.fft.rfftfreq(p, 1.0) * p - - return power, freqs - - -def topk_template_freqs_1d(template_1d: np.ndarray, K: int, min_power: float = 1e-20): - """ - Return top-K frequency indices by power for 1D template. 
- - Args: - template_1d: 1D template array (p,) - K: Number of top frequencies to return - min_power: Minimum power threshold - - Returns: - List of frequency indices (as integers) - """ - power, _ = get_power_1d(template_1d) - mask = power > min_power - if not np.any(mask): - return [] - valid_power = power[mask] - valid_indices = np.flatnonzero(mask) - top_idx = valid_indices[np.argsort(valid_power)[::-1]][:K] - return top_idx.tolist() - - -def topk_template_freqs(template_2d: np.ndarray, K: int, min_power: float = 1e-20): - """ - Return top-K (kx, ky) rFFT2 bins by power from get_power_2d_adele(template_2d). - """ - freqs_u, freqs_v, power = get_power_2d_adele(template_2d) # power shape: (p1, p2//2 + 1) - shp = power.shape - flat = power.ravel() - mask = flat > min_power - if not np.any(mask): - return [] - top_idx = np.flatnonzero(mask)[np.argsort(flat[mask])[::-1]][:K] - kx, ky = np.unravel_index(top_idx, shp) - return list(zip(kx.tolist(), ky.tolist())) - - -def get_power_2d_adele(points, no_freq=False): - """ - Compute 2D power spectrum using rfft2 with proper symmetry handling. 
- - Args: - points: (M, N) array, the 2D signal - no_freq: if True, only return power (no frequency arrays) - - Returns: - freqs_u: frequency bins for rows (if no_freq=False) - freqs_v: frequency bins for columns (if no_freq=False) - power: 2D power spectrum (M, N//2 + 1) - """ - M, N = points.shape - - # Perform 2D rFFT - ft = np.fft.rfft2(points) - - # Power spectrum normalized by total number of samples - power = np.abs(ft) ** 2 / (M * N) - - # Construct weighting to handle real conjugate symmetry - weight = 2 * np.ones((M, N // 2 + 1)) - weight[0, 0] = 1 # handles DC component - weight[(M // 2 + 1) :, 0] = 0 # handles DC frequency in second axis - if M % 2 == 0: - weight[M // 2, 0] = 1 - if N % 2 == 0: - weight[(M // 2 + 1) :, N // 2] = 0 - weight[0, N // 2] = 1 - if (M % 2 == 0) and (N % 2 == 0): - weight[M // 2, N // 2] = 1 - - # Reweight power to account for redundancies - power = weight * power - - # Check Parseval's theorem - total_power = np.sum(power) - norm_squared = np.linalg.norm(points) ** 2 - if not np.isclose(total_power, norm_squared, rtol=1e-6): - print( - f"Warning: Total power {total_power:.3f} does not match norm squared {norm_squared:.3f}" - ) - - if no_freq: - return power - - # Frequency bins - freqs_u = np.fft.fftfreq(M) # full symmetric frequencies (rows) - freqs_v = np.fft.rfftfreq(N) # only non-negative frequencies (columns) - - return freqs_u, freqs_v, power - - -def compute_theoretical_loss_levels_2d(template_2d): - """ - Compute theoretical MSE loss levels based on template power spectrum. - - Returns both the initial loss (before learning) and final loss (fully converged). - The theory predicts step-wise loss reductions as each Fourier mode is learned. 
- - Args: - template_2d: 2D template array (p1, p2) - - Returns: - dict with: - 'initial': Expected MSE before any learning (= Var(template)) - 'final': Expected MSE when fully converged (~0) - 'levels': All intermediate loss plateaus - """ - p1, p2 = template_2d.shape - power = get_power_2d_adele(template_2d, no_freq=True) - - power_flat = power.flatten() - power_flat = np.sort(power_flat[power_flat > 1e-20])[::-1] # Descending - - coef = 1.0 / (p1 * p2) - - # Theory levels: cumulative tail sums - levels = [coef * np.sum(power_flat[k:]) for k in range(len(power_flat) + 1)] - - return { - "initial": levels[0] if levels else 0.0, # Before learning any mode - "final": 0.0, # When all modes are learned - "levels": levels, - } - - -def compute_theoretical_loss_levels_1d(template_1d): - """ - Compute theoretical MSE loss levels based on 1D template power spectrum. - - Args: - template_1d: 1D template array (p,) - - Returns: - dict with: - 'initial': Expected MSE before any learning - 'final': Expected MSE when fully converged (~0) - 'levels': All intermediate loss plateaus - """ - p = len(template_1d) - power, _ = get_power_1d(template_1d) - - power = np.sort(power[power > 1e-20])[::-1] # Descending - - coef = 1.0 / p - - # Theory levels: cumulative tail sums - levels = [coef * np.sum(power[k:]) for k in range(len(power) + 1)] - - return { - "initial": levels[0] if levels else 0.0, - "final": 0.0, - "levels": levels, - } - - -# Backward compatibility aliases -def compute_theoretical_final_loss_2d(template_2d): - """Returns expected initial loss (for setting convergence targets).""" - return compute_theoretical_loss_levels_2d(template_2d)["initial"] - - -def compute_theoretical_final_loss_1d(template_1d): - """Returns expected initial loss (for setting convergence targets).""" - return compute_theoretical_loss_levels_1d(template_1d)["initial"] - - -def _tracked_power_from_fft2(power2d, kx, ky, p1, p2): - """ - Sum power at (kx, ky) and its real-signal mirror (-kx, -ky). 
- - For real signals, the full FFT has conjugate symmetry, so power at (kx, ky) - and (-kx, -ky) are equal. This helper sums both for consistent power measurement. - - Args: - power2d: 2D power spectrum from fft2 (shape: p1, p2) - kx, ky: Frequency indices - p1, p2: Dimensions of the signal - - Returns: - float: Total power at this frequency (including mirror) - """ - i0, j0 = kx % p1, ky % p2 - i1, j1 = (-kx) % p1, (-ky) % p2 - if (i0, j0) == (i1, j1): - return float(power2d[i0, j0]) - return float(power2d[i0, j0] + power2d[i1, j1]) - - -def _squareish_grid(n): - """Compute nearly-square grid dimensions for n items.""" - c = int(np.ceil(np.sqrt(n))) - r = int(np.ceil(n / c)) - return r, c - - -def _fourier_mode_2d(p1: int, p2: int, kx: int, ky: int, phase: float = 0.0): - """Generate a 2D Fourier mode (cosine wave), normalized to [0, 1].""" - y = np.arange(p1)[:, None] - x = np.arange(p2)[None, :] - mode = np.cos(2 * np.pi * (ky * y / p1 + kx * x / p2) + phase) - mmin, mmax = mode.min(), mode.max() - return (mode - mmin) / (mmax - mmin) if mmax > mmin else mode - - -def _signed_k(k: int, n: int) -> int: - """Convert frequency index to signed representation (-n/2 to n/2).""" - return k if k <= n // 2 else k - n - - -def _pretty_k(k: int, n: int) -> str: - """Format frequency for display (handles Nyquist frequency with ± symbol).""" - if n % 2 == 0 and k == n // 2: - return rf"\pm{n // 2}" - return f"{_signed_k(k, n)}" - - -def _permutation_from_groups_with_dead( - dom_idx, phase, dom_power, l2, *, within="phase", dead_l2_thresh=1e-1 -): - """ - Create neuron permutation grouped by dominant frequency. 
- - Args: - dom_idx: Dominant frequency index for each neuron - phase: Phase at dominant frequency for each neuron - dom_power: Power at dominant frequency for each neuron - l2: L2 norm of each neuron's weights - within: How to order within groups ('phase', 'power', 'phase_power', 'none') - dead_l2_thresh: L2 threshold below which neurons are "dead" - - Returns: - perm: Permutation indices - ordered_keys: Ordered list of group keys (-1 for dead) - boundaries: Cumulative indices where groups end - """ - dead_mask = l2 < float(dead_l2_thresh) - groups = {} - for i, f in enumerate(dom_idx): - key = -1 if dead_mask[i] else int(f) - groups.setdefault(key, []).append(i) - - freq_keys = sorted([k for k in groups.keys() if k >= 0]) - ordered_keys = freq_keys + ([-1] if -1 in groups else []) - - perm, boundaries = [], [] - for f in ordered_keys: - idxs = groups[f] - if f == -1: - idxs = sorted(idxs, key=lambda i: l2[i]) - else: - if within == "phase" and phase is not None: - idxs = sorted(idxs, key=lambda i: (phase[i] + 2 * np.pi) % (2 * np.pi)) - elif within == "power" and dom_power is not None: - idxs = sorted(idxs, key=lambda i: -dom_power[i]) - elif within == "phase_power": - idxs = sorted( - idxs, key=lambda i: ((phase[i] + 2 * np.pi) % (2 * np.pi), -dom_power[i]) - ) - perm.extend(idxs) - boundaries.append(len(perm)) - - return np.array(perm, dtype=int), ordered_keys, boundaries - - -def plot_training_loss_with_theory( - loss_history, template_2d, p1, p2, x_values=None, x_label="Step", save_path=None, show=True -): - """ - Plot training loss with theoretical power spectrum lines. - - Args: - loss_history: List of loss values - template_2d: The 2D template array (p1, p2) - p1, p2: Dimensions - x_values: X-axis values (if None, uses indices 0, 1, 2, ...) 
- x_label: Label for x-axis (e.g., "Samples Seen", "Fraction of Space") - save_path: Optional path to save figure - show: Whether to display the plot - """ - fig, ax = plt.subplots(1, 1, figsize=(10, 6)) - - # Use provided x_values or default to indices - if x_values is None: - x_values = np.arange(len(loss_history)) - - # Plot loss - ax.plot(x_values, loss_history, lw=4, color="#1f77b4", label="Training Loss") - - # Compute power spectrum of template - x_freq, y_freq, power = get_power_2d_adele(template_2d) - power = power.flatten() - valid = power > 1e-20 - power = power[valid] - power = np.sort(power)[::-1] # Descending order - - # Plot theoretical lines (cumulative tail sums) - alpha_values = [np.sum(power[k:]) for k in range(len(power))] - coef = 1 / (p1 * p2) - for k, alpha in enumerate(alpha_values): - ax.axhline(y=coef * alpha, color="black", linestyle="--", linewidth=2, zorder=-2) - - ax.set_xlabel(x_label, fontsize=24) - ax.set_ylabel("Train Loss", fontsize=24) - - style_axes(ax) - ax.grid(False) - - plt.tight_layout() - - if save_path: - plt.savefig(save_path, bbox_inches="tight", dpi=150) - print(f" ✓ Saved loss plot to {save_path}") - - if show: - plt.show() - else: - plt.close() - - return fig, ax - - -def plot_model_predictions_over_time( - model, - param_history, - X_data, - Y_data, - p1, - p2, - steps=None, - example_idx=None, - cmap="gray", - save_path=None, - show=False, -): - """ - Plot model predictions at different training steps vs ground truth. 
- - Args: - model: The trained model - param_history: List of parameter snapshots from training - X_data: Input tensor (N, k, p1*p2) - Y_data: Target tensor (N, p1*p2) - p1, p2: Dimensions - steps: List of epoch indices to plot (default: [1, 5, 10, final]) - example_idx: Index of example to visualize (default: random) - cmap: Colormap to use - save_path: Path to save figure - show: Whether to display the plot - """ - import torch - - # Default steps - if steps is None: - final_step = len(param_history) - 1 - steps = [1, min(5, final_step), min(10, final_step), final_step] - steps = sorted(list(set(steps))) # Remove duplicates - - # Random example if not specified - if example_idx is None: - example_idx = int(np.random.randint(len(Y_data))) - - device = next(model.parameters()).device - model.to(device).eval() - - # Ground truth - if Y_data.dim() == 3: - Y_data = Y_data[:, -1, :] # only final time step - with torch.no_grad(): - truth_2d = Y_data[example_idx].reshape(p1, p2).cpu().numpy() - - # Collect predictions at each step - preds = [] - for step in steps: - model.load_state_dict(param_history[step], strict=True) - with torch.no_grad(): - x = X_data[example_idx : example_idx + 1].to(device) - pred_2d = model(x) - if pred_2d.dim() == 3: - pred_2d = pred_2d[:, -1, :] # only final time step - - pred_2d = pred_2d.reshape(p1, p2).detach().cpu().numpy() - - preds.append(pred_2d) - - # Shared color scale based on ground truth - vmin = np.min(truth_2d) - vmax = np.max(truth_2d) - - # Plot: rows = [Prediction, Target], cols = time steps - fig, axes = plt.subplots(2, len(steps), figsize=(3.5 * len(steps), 6), layout="constrained") - - # Handle case where there's only one step - if len(steps) == 1: - axes = axes.reshape(2, 1) - - for col, (step, pred_2d) in enumerate(zip(steps, preds)): - # Prediction - im = axes[0, col].imshow(pred_2d, vmin=vmin, vmax=vmax, cmap=cmap, origin="upper") - axes[0, col].set_title(f"Epoch {step}", fontsize=12) - axes[0, col].set_xticks([]) - 
axes[0, col].set_yticks([]) - - # Target (same for all columns) - axes[1, col].imshow(truth_2d, vmin=vmin, vmax=vmax, cmap=cmap, origin="upper") - axes[1, col].set_xticks([]) - axes[1, col].set_yticks([]) - - axes[0, 0].set_ylabel("Prediction", fontsize=14) - axes[1, 0].set_ylabel("Target", fontsize=14) - - # Single shared colorbar on the right - fig.colorbar(im, ax=axes, location="right", shrink=0.9, pad=0.02).set_label( - "Value", fontsize=12 - ) - - if save_path: - plt.savefig(save_path, bbox_inches="tight", dpi=150) - print(f" ✓ Saved predictions plot to {save_path}") - - if show: - plt.show() - else: - plt.close() - - return fig, axes - - -def plot_model_predictions_over_time_1d( - model, - param_history, - X_data, - Y_data, - p, - steps=None, - example_idx=None, - save_path=None, - show=False, -): - """ - Plot model predictions at different training steps vs ground truth (1D version). - - Args: - model: The trained model - param_history: List of parameter snapshots from training - X_data: Input tensor (N, k, p) - Y_data: Target tensor (N, p) - p: Dimension - steps: List of epoch indices to plot (default: [1, 5, 10, final]) - example_idx: Index of example to visualize (default: random) - save_path: Path to save figure - show: Whether to display the plot - """ - import torch - - # Default steps - if steps is None: - final_step = len(param_history) - 1 - steps = [1, min(5, final_step), min(10, final_step), final_step] - steps = sorted(list(set(steps))) - - # Random example if not specified - if example_idx is None: - example_idx = int(np.random.randint(len(Y_data))) - - device = next(model.parameters()).device - model.to(device).eval() - - # Ground truth - if Y_data.dim() == 3: - Y_data = Y_data[:, -1, :] # only final time step - with torch.no_grad(): - truth_1d = Y_data[example_idx].cpu().numpy() - - # Collect predictions at each step - preds = [] - for step in steps: - model.load_state_dict(param_history[step], strict=True) - with torch.no_grad(): - x = 
X_data[example_idx : example_idx + 1].to(device) - pred = model(x) - if pred.dim() == 3: - pred = pred[:, -1, :] # only final time step - pred_1d = pred.squeeze().detach().cpu().numpy() - preds.append(pred_1d) - - # Plot: rows = [Prediction, Target], cols = time steps - fig, axes = plt.subplots(2, len(steps), figsize=(3.5 * len(steps), 4), layout="constrained") - - # Handle case where there's only one step - if len(steps) == 1: - axes = axes.reshape(2, 1) - - x = np.arange(p) - - for col, (step, pred_1d) in enumerate(zip(steps, preds)): - # Prediction - axes[0, col].plot(x, pred_1d, "b-", lw=2) - axes[0, col].set_title(f"Epoch {step}", fontsize=12) - axes[0, col].set_ylim( - truth_1d.min() - 0.1 * np.abs(truth_1d.min()), - truth_1d.max() + 0.1 * np.abs(truth_1d.max()), - ) - axes[0, col].set_xticks([]) - axes[0, col].grid(True, alpha=0.3) - - # Target (same for all columns) - axes[1, col].plot(x, truth_1d, "k-", lw=2) - axes[1, col].set_ylim( - truth_1d.min() - 0.1 * np.abs(truth_1d.min()), - truth_1d.max() + 0.1 * np.abs(truth_1d.max()), - ) - axes[1, col].set_xticks([]) - axes[1, col].grid(True, alpha=0.3) - - axes[0, 0].set_ylabel("Prediction", fontsize=14) - axes[1, 0].set_ylabel("Target", fontsize=14) - - if save_path: - plt.savefig(save_path, bbox_inches="tight", dpi=150) - print(f" ✓ Saved predictions plot to {save_path}") - - if show: - plt.show() - else: - plt.close() - - return fig, axes - - -def plot_prediction_power_spectrum_over_time( - model, - param_history, - X_data, - Y_data, - template_2d, - p1, - p2, - loss_history, - param_save_indices=None, - num_freqs_to_track=10, - checkpoint_indices=None, - num_samples=100, - save_path=None, - show=False, -): - """ - Plot training loss with power spectrum analysis of predictions over time. 
- - Creates a two-panel plot: - - Top: Training loss with colored bands for theory lines - - Bottom: Power in tracked frequencies over time (computed at ALL saved checkpoints) - - Args: - model: The trained model - param_history: List of parameter snapshots (includes epoch 0) - X_data: Input tensor (N, k, p1*p2) - Y_data: Target tensor (N, p1*p2) - used to compute loss history - template_2d: The 2D template array - p1, p2: Dimensions - loss_history: List of loss values over training steps/epochs - param_save_indices: List of step/epoch numbers where params were saved (for x-axis alignment) - num_freqs_to_track: Number of top frequencies to track - checkpoint_indices: (deprecated/unused) - now analyzes ALL checkpoints - num_samples: Number of samples to average for power computation - save_path: Path to save figure - show: Whether to display the plot - """ - import torch - from matplotlib.ticker import FormatStrFormatter - from tqdm import tqdm - - device = next(model.parameters()).device - - # Identify top-K frequencies from template - tracked_freqs = topk_template_freqs(template_2d, K=num_freqs_to_track) - template_power_2d = get_power_2d_adele(template_2d, no_freq=True) - target_powers = {(kx, ky): template_power_2d[kx, ky] for (kx, ky) in tracked_freqs} - - # Analyze ALL saved parameter checkpoints for full temporal resolution - T = len(param_history) - - steps_analysis = list(range(len(param_history))) # Analyze ALL saved params - - # Get the actual step/epoch numbers for x-axis plotting - if param_save_indices is not None: - actual_steps = param_save_indices # All the actual step numbers - else: - actual_steps = list(range(len(param_history))) # If None, indices = steps - - # Track average output power at those frequencies over training - powers_over_time = {freq: [] for freq in tracked_freqs} - - print(f" Analyzing {len(steps_analysis)} checkpoints for power spectrum...") - - with torch.no_grad(): - for step in tqdm(steps_analysis, desc=" Computing power 
spectra", leave=False): - model.load_state_dict(param_history[step], strict=True) - model.eval() - - # Get predictions for a batch - outputs_flat = ( - model(X_data[:num_samples].to(device)).detach().cpu().numpy() - ) # (num_samples, p1*p2) - - # Compute power spectrum for each sample, then average - powers_batch = [] - for i in range(outputs_flat.shape[0]): - if outputs_flat.ndim == 3: - out_2d = outputs_flat[i][-1, :] # only final time step - else: - out_2d = outputs_flat[i] - out_2d = out_2d.reshape(p1, p2) - power_i = get_power_2d_adele(out_2d, no_freq=True) # (p1, p2//2+1) - powers_batch.append(power_i) - avg_power = np.mean(powers_batch, axis=0) # (p1, p2//2+1) - - # Record power at each tracked frequency - for kx, ky in tracked_freqs: - powers_over_time[(kx, ky)].append(avg_power[kx, ky]) - - # Convert lists to arrays - for freq in tracked_freqs: - powers_over_time[freq] = np.array(powers_over_time[freq]) - - if param_save_indices is None: - # Assume params were saved at every step (old behavior) - loss_epochs = np.arange(len(param_history)) - loss_history_subset = loss_history - else: - # Use the provided indices - loss_epochs = np.array(param_save_indices) - # Extract only the loss values at those indices - loss_history_subset = [loss_history[i] for i in param_save_indices] - - # --- Create the plot --- - colors = plt.cm.tab10(np.linspace(0, 1, len(tracked_freqs))) - - fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10), sharex=True) - fig.subplots_adjust(left=0.12, right=0.98, top=0.96, bottom=0.10, hspace=0.12) - - # --- Top panel: Training loss with theory bands --- - ax1.plot(loss_epochs, loss_history_subset, lw=4, color="#1f77b4", label="Training Loss") - - # Compute power spectrum of template for theory lines - _, _, power = get_power_2d_adele(template_2d) - power_flat = np.sort(power.flatten()[power.flatten() > 1e-20])[::-1] - - # Theory levels (cumulative tail sums) - alpha_values = np.array([np.sum(power_flat[k:]) for k in 
range(len(power_flat))]) - coef = 1.0 / (p1 * p2) - y_levels = coef * alpha_values # strictly decreasing - - # Shade horizontal bands between successive theory lines - n_bands = min(len(tracked_freqs), len(y_levels) - 1) - for i in range(n_bands): - y_top = y_levels[i] - y_bot = y_levels[i + 1] - ax1.axhspan(y_bot, y_top, facecolor=colors[i], alpha=0.15, zorder=-3) - - # Draw the black theory lines - for y in y_levels[: n_bands + 1]: - ax1.axhline(y=y, color="black", linestyle="--", linewidth=2, zorder=-2) - - ax1.set_ylabel("Theory Loss Levels", fontsize=20) - ax1.set_ylim(y_levels[n_bands], y_levels[0] * 1.1) - style_axes(ax1) - ax1.grid(False) - ax1.tick_params(labelbottom=False) - - # --- Bottom panel: Tracked mode power over time --- - for i, (kx, ky) in enumerate(tracked_freqs): - ax2.plot( - actual_steps, # Use actual step/epoch numbers, not indices - powers_over_time[(kx, ky)], - color=colors[i], - lw=3, - label=f"({kx},{ky})", - ) - ax2.axhline( - target_powers[(kx, ky)], - color=colors[i], - linestyle="dotted", - linewidth=2, - alpha=0.5, - ) - - ax2.set_xlabel("Steps", fontsize=20) - ax2.set_ylabel("Power in Prediction", fontsize=20) - ax2.grid(True, alpha=0.3) - style_axes(ax2) - ax2.yaxis.set_major_formatter(FormatStrFormatter("%.1f")) - - if save_path: - plt.savefig(save_path, bbox_inches="tight", dpi=150) - print(f" ✓ Saved power spectrum plot to {save_path}") - - if show: - plt.show() - else: - plt.close() - - return fig, (ax1, ax2), powers_over_time, tracked_freqs - - -def plot_prediction_power_spectrum_over_time_1d( - model, - param_history, - X_data, - Y_data, - template_1d, - p, - loss_history, - param_save_indices=None, - num_freqs_to_track=10, - checkpoint_indices=None, - num_samples=100, - save_path=None, - show=False, -): - """ - Plot training loss with power spectrum analysis of predictions over time (1D version). 
- - Creates a two-panel plot: - - Top: Training loss with colored bands for theory lines - - Bottom: Power in tracked frequencies over time (computed at ALL saved checkpoints) - - Args: - model: The trained model - param_history: List of parameter snapshots (includes epoch 0) - X_data: Input tensor (N, k, p) - Y_data: Target tensor (N, p) - template_1d: The 1D template array (p,) - p: Dimension of the template - loss_history: List of loss values over training steps/epochs - param_save_indices: List of step/epoch numbers where params were saved - num_freqs_to_track: Number of top frequencies to track - checkpoint_indices: (deprecated/unused) - now analyzes ALL checkpoints - num_samples: Number of samples to average for power computation - save_path: Path to save figure - show: Whether to display the plot - """ - import torch - from matplotlib.ticker import FormatStrFormatter - from tqdm import tqdm - - device = next(model.parameters()).device - - # Identify top-K frequencies from template - tracked_freqs = topk_template_freqs_1d(template_1d, K=num_freqs_to_track) - template_power, _ = get_power_1d(template_1d) - target_powers = {k: template_power[k] for k in tracked_freqs} - - # Analyze ALL saved parameter checkpoints - T = len(param_history) - steps_analysis = list(range(len(param_history))) - - # Get the actual step/epoch numbers for x-axis - if param_save_indices is not None: - actual_steps = param_save_indices - else: - actual_steps = list(range(len(param_history))) - - # Track average output power at those frequencies over training - powers_over_time = {freq: [] for freq in tracked_freqs} - - print(f" Analyzing {len(steps_analysis)} checkpoints for power spectrum (1D)...") - - with torch.no_grad(): - for step in tqdm(steps_analysis, desc=" Computing power spectra", leave=False): - model.load_state_dict(param_history[step], strict=True) - model.eval() - - # Get predictions for a batch - outputs_flat = ( - 
model(X_data[:num_samples].to(device)).detach().cpu().numpy() - ) # (num_samples, p) - - # Compute power spectrum for each sample, then average - powers_batch = [] - for i in range(outputs_flat.shape[0]): - if outputs_flat.ndim == 3: - out_1d = outputs_flat[i, -1, :] # only final time step - else: - out_1d = outputs_flat[i] - power_i, _ = get_power_1d(out_1d) - powers_batch.append(power_i) - avg_power = np.mean(powers_batch, axis=0) # (p//2+1,) - - # Record power at each tracked frequency - for k in tracked_freqs: - powers_over_time[k].append(avg_power[k]) - - # Convert lists to arrays - for freq in tracked_freqs: - powers_over_time[freq] = np.array(powers_over_time[freq]) - - if param_save_indices is None: - loss_epochs = np.arange(len(param_history)) - loss_history_subset = loss_history - else: - loss_epochs = np.array(param_save_indices) - loss_history_subset = [loss_history[i] for i in param_save_indices] - - # --- Create the plot --- - colors = plt.cm.tab10(np.linspace(0, 1, len(tracked_freqs))) - - fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10), sharex=True) - fig.subplots_adjust(left=0.12, right=0.98, top=0.96, bottom=0.10, hspace=0.12) - - # --- Top panel: Training loss with theory bands --- - ax1.plot(loss_epochs, loss_history_subset, lw=4, color="#1f77b4", label="Training Loss") - - # Compute power spectrum of template for theory lines - power, _ = get_power_1d(template_1d) - power_sorted = np.sort(power[power > 1e-20])[::-1] - - # Theory levels (cumulative tail sums) - alpha_values = np.array([np.sum(power_sorted[k:]) for k in range(len(power_sorted))]) - coef = 1.0 / p - y_levels = coef * alpha_values # strictly decreasing - - # Shade horizontal bands between successive theory lines - n_bands = min(len(tracked_freqs), len(y_levels) - 1) - for i in range(n_bands): - y_top = y_levels[i] - y_bot = y_levels[i + 1] - ax1.axhspan(y_bot, y_top, facecolor=colors[i], alpha=0.15, zorder=-3) - - # Draw the black theory lines - for y in y_levels[: n_bands + 
1]: - ax1.axhline(y=y, color="black", linestyle="--", linewidth=2, zorder=-2) - - ax1.set_ylabel("Theory Loss Levels", fontsize=20) - ax1.set_ylim(y_levels[n_bands], y_levels[0] * 1.1) - style_axes(ax1) - ax1.grid(False) - ax1.tick_params(labelbottom=False) - - # --- Bottom panel: Tracked mode power over time --- - for i, k in enumerate(tracked_freqs): - ax2.plot(actual_steps, powers_over_time[k], color=colors[i], lw=3, label=f"k={k}") - ax2.axhline( - target_powers[k], - color=colors[i], - linestyle="dotted", - linewidth=2, - alpha=0.5, - ) - - ax2.set_xlabel("Steps", fontsize=20) - ax2.set_ylabel("Power in Prediction", fontsize=20) - ax2.grid(True, alpha=0.3) - ax2.legend(fontsize=10, loc="best", ncol=2) - style_axes(ax2) - ax2.yaxis.set_major_formatter(FormatStrFormatter("%.1f")) - - if save_path: - plt.savefig(save_path, bbox_inches="tight", dpi=150) - print(f" ✓ Saved power spectrum plot to {save_path}") - - if show: - plt.show() - else: - plt.close() - - return fig, (ax1, ax2), powers_over_time, tracked_freqs - - -def plot_fourier_modes_reference( - tracked_freqs, - colors, - p1, - p2, - save_path=None, - save_individual=False, - individual_dir=None, - show=False, -): - """ - Create a reference visualization of tracked Fourier modes. - - Generates a stacked vertical image showing all tracked frequency modes - with colored borders matching the power spectrum analysis. 
- - Args: - tracked_freqs: List of (kx, ky) tuples for tracked frequencies - colors: Array of colors for each frequency (from plt.cm.tab10 or similar) - p1, p2: Dimensions of the template - save_path: Path to save the stacked visualization - save_individual: Whether to also save individual mode images - individual_dir: Directory for individual mode images (if save_individual=True) - show: Whether to display the plot - - Returns: - fig: The matplotlib figure - """ - from pathlib import Path - - import matplotlib.gridspec as gridspec - import matplotlib.patheffects as pe - - # --- Save individual mode images (optional) --- - if save_individual and individual_dir is not None: - individual_dir = Path(individual_dir) - individual_dir.mkdir(exist_ok=True) - - for i, (kx, ky) in enumerate(tracked_freqs): - img = _fourier_mode_2d(p1, p2, kx, ky) - - fig_ind, ax = plt.subplots(figsize=(3.2, 2.2)) - ax.imshow(img, cmap="RdBu_r", origin="upper") - ax.set_xticks([]) - ax.set_yticks([]) - - # Colored border - for side in ("left", "right", "top", "bottom"): - ax.spines[side].set_edgecolor(colors[i]) - ax.spines[side].set_linewidth(8) - - # Frequency label - kx_label = _pretty_k(kx, p2) - ky_label = _pretty_k(ky, p1) - ax.text( - 0.5, - 0.5, - f"$k=({kx_label},{ky_label})$", - color=colors[i], - fontsize=25, - fontweight="bold", - ha="center", - va="center", - transform=ax.transAxes, - path_effects=[pe.withStroke(linewidth=3, foreground="white", alpha=0.8)], - ) - - plt.tight_layout() - - # Save with signed indices in filename - kx_signed, ky_signed = _signed_k(kx, p2), _signed_k(ky, p1) - base = f"mode_{i:03d}_kx{kx}_ky{ky}_signed_{kx_signed}_{ky_signed}" - fig_ind.savefig(individual_dir / f"{base}.png", dpi=300, bbox_inches="tight") - np.save(individual_dir / f"{base}.npy", img) - plt.close(fig_ind) - - print(f" ✓ Saved {len(tracked_freqs)} individual mode images to {individual_dir}") - - # --- Create stacked vertical visualization --- - n = len(tracked_freqs) - - # Panel 
geometry and spacing - panel_h_in = 2.2 - gap_h_in = 0.35 # whitespace between rows - fig_w_in = 4.6 - fig_h_in = n * panel_h_in + (n - 1) * gap_h_in - - fig = plt.figure(figsize=(fig_w_in, fig_h_in), dpi=150) - - # Rows alternate: [panel, gap, panel, gap, ..., panel] - rows = 2 * n - 1 - height_ratios = [] - for i in range(n): - height_ratios.append(panel_h_in) - if i < n - 1: - height_ratios.append(gap_h_in) - - # Layout: image on LEFT, label on RIGHT - gs = gridspec.GridSpec( - nrows=rows, - ncols=2, - width_ratios=[1.0, 0.46], - height_ratios=height_ratios, - wspace=0.0, - hspace=0.0, - ) - - for i, (kx, ky) in enumerate(tracked_freqs): - r = 2 * i # even rows are content; odd rows are spacers - - # Image axis (left) - ax_img = fig.add_subplot(gs[r, 0]) - img = _fourier_mode_2d(p1, p2, kx, ky) - ax_img.imshow(img, cmap="RdBu_r", origin="upper", aspect="equal") - ax_img.set_xticks([]) - ax_img.set_yticks([]) - - # Colored border around the image - for side in ("left", "right", "top", "bottom"): - ax_img.spines[side].set_edgecolor(colors[i]) - ax_img.spines[side].set_linewidth(8) - - # Label axis (right) - ax_label = fig.add_subplot(gs[r, 1]) - ax_label.set_axis_off() - kx_label = _pretty_k(kx, p2) - ky_label = _pretty_k(ky, p1) - ax_label.text( - 0.0, - 0.5, - f"$k=({kx_label},{ky_label})$", - color=colors[i], - fontsize=45, - fontweight="bold", - ha="left", - va="center", - transform=ax_label.transAxes, - path_effects=[pe.withStroke(linewidth=3, foreground="white", alpha=0.8)], - ) - - # Adjust to prevent clipping of thick borders - fig.subplots_adjust(left=0.02, right=0.98, top=0.985, bottom=0.015) - - if save_path: - fig.savefig(save_path, dpi=150, bbox_inches="tight", pad_inches=0.12) - print(f" ✓ Saved Fourier modes reference to {save_path}") - - if show: - plt.show() - else: - plt.close() - - return fig - - -def plot_wout_neuron_specialization( - param_history, - tracked_freqs, - colors, - p1, - p2, - steps=None, - dead_thresh_l2=0.25, - save_dir=None, - 
show=False, -): - """ - Visualize W_out neurons colored by their dominant tracked frequency. - - Creates grid visualizations of output weight neurons at different training steps, - with colored borders indicating which Fourier mode each neuron is tuned to. - - Args: - param_history: List of parameter snapshots from training - tracked_freqs: List of (kx, ky) tuples for tracked frequencies - colors: Array of colors for each frequency (from plt.cm.tab10 or similar) - p1, p2: Dimensions of the template - steps: List of epoch indices to plot (default: [1, 5, final]) - dead_thresh_l2: L2 norm threshold below which neurons are considered "dead" - save_dir: Directory to save figures (Path object) - show: Whether to display the plots - - Returns: - List of figure objects - """ - from pathlib import Path - - import matplotlib.cm as cm - import matplotlib.gridspec as gridspec - from matplotlib.colors import Normalize - from matplotlib.patches import Patch - - # Default steps - if steps is None: - final_step = len(param_history) - 1 - steps = [1, min(5, final_step), final_step] - steps = sorted(list(set(steps))) - - # Get dimensions - W0 = param_history[steps[0]]["W_out"].detach().cpu().numpy().T # (H, D) - H, D = W0.shape - assert p1 * p2 == D, f"p1*p2 ({p1 * p2}) must equal D ({D})." 
- - # Compute global color limits across all steps - vmin, vmax = np.inf, -np.inf - for step in steps: - W = param_history[step]["W_out"].detach().cpu().numpy().T - vmin = min(vmin, W.min()) - vmax = max(vmax, W.max()) - - # Grid layout - R_ner, C_ner = _squareish_grid(H) - tile_w, tile_h = 2, 2 # inches per neuron tile - figsize = (C_ner * tile_w, R_ner * tile_h) - - heat_cmap = "RdBu_r" - border_lw = 5.0 - dead_color = (0.6, 0.6, 0.6, 1.0) - - figures = [] - - # Create one figure per time step - for step in steps: - W = param_history[step]["W_out"].detach().cpu().numpy().T # (H, D) - - # Determine dominant frequency for each neuron - dom_idx = np.empty(H, dtype=int) - l2 = np.linalg.norm(W, axis=1) - dead_mask = l2 < dead_thresh_l2 - - for j in range(H): - m = W[j].reshape(p1, p2) - F = np.fft.fft2(m) - P = (F.conj() * F).real - tp = [_tracked_power_from_fft2(P, kx, ky, p1, p2) for (kx, ky) in tracked_freqs] - dom_idx[j] = int(np.argmax(tp)) - - # Assign colors - edge_colors = colors[dom_idx].copy() - edge_colors[dead_mask] = dead_color - - # Build figure - fig = plt.figure(figsize=figsize) - gs = gridspec.GridSpec(R_ner, C_ner, figure=fig, wspace=0.06, hspace=0.06) - - # Plot neuron tiles - for j in range(R_ner * C_ner): - ax = fig.add_subplot(gs[j // C_ner, j % C_ner]) - if j < H: - m = W[j].reshape(p1, p2) - ax.imshow(m, vmin=vmin, vmax=vmax, origin="lower", aspect="equal", cmap=heat_cmap) - # Colored border - ec = edge_colors[j] - for sp in ax.spines.values(): - sp.set_edgecolor(ec) - sp.set_linewidth(border_lw) - else: - ax.axis("off") - - ax.set_xticks([]) - ax.set_yticks([]) - - if save_dir: - save_path = Path(save_dir) / f"wout_neurons_epoch_{step:04d}.pdf" - fig.savefig(save_path, bbox_inches="tight", dpi=200) - print(f" ✓ Saved W_out visualization for epoch {step}") - - if show: - plt.show() - else: - plt.close() - - figures.append(fig) - - # Create standalone colorbar figure - fig_cb = plt.figure(figsize=(6, 1.2)) - ax_cb = fig_cb.add_axes([0.1, 0.35, 
0.8, 0.3]) - norm = Normalize(vmin=vmin, vmax=vmax) - sm = cm.ScalarMappable(norm=norm, cmap=heat_cmap) - cbar = fig_cb.colorbar(sm, cax=ax_cb, orientation="horizontal") - cbar.set_label("Weight value", fontsize=12) - - if save_dir: - save_path = Path(save_dir) / "wout_colorbar.pdf" - fig_cb.savefig(save_path, bbox_inches="tight", dpi=150) - print(" ✓ Saved colorbar") - - if show: - plt.show() - else: - plt.close() - - figures.append(fig_cb) - - # Create standalone legend figure - fig_legend = plt.figure(figsize=(6, 2.0)) - ax_leg = fig_legend.add_subplot(111) - ax_leg.axis("off") - - # Colored edge patches (matching tile borders) - handles = [ - Patch(facecolor="white", edgecolor=colors[i], linewidth=2.5, label=f"k={tracked_freqs[i]}") - for i in range(len(tracked_freqs)) - ] - handles.append(Patch(facecolor="white", edgecolor=dead_color, linewidth=2.5, label="dead")) - - ax_leg.legend( - handles=handles, - ncol=min(4, len(handles)), - frameon=True, - loc="center", - title="Dominant frequency", - fontsize=10, - ) - - if save_dir: - save_path = Path(save_dir) / "wout_legend.pdf" - fig_legend.savefig(save_path, bbox_inches="tight", dpi=150) - print(" ✓ Saved legend") - - if show: - plt.show() - else: - plt.close() - - figures.append(fig_legend) - - return figures - - -def plot_wout_neuron_specialization_1d( - param_history, - tracked_freqs, - colors, - p, - steps=None, - dead_thresh_l2=0.25, - save_dir=None, - show=False, -): - """ - Visualize W_out neurons colored by their dominant tracked frequency (1D version). - - Creates visualizations of output weight neurons at different training steps, - with colored borders indicating which Fourier mode each neuron is tuned to. - For 1D, neurons are shown as line plots. 
- - Args: - param_history: List of parameter snapshots from training - tracked_freqs: List of frequency indices (integers) - colors: Array of colors for each frequency - p: Dimension of the template - steps: List of epoch indices to plot (default: [1, 5, final]) - dead_thresh_l2: L2 norm threshold below which neurons are considered "dead" - save_dir: Directory to save figures (Path object) - show: Whether to display the plots - - Returns: - List of figure objects - """ - from pathlib import Path - - from matplotlib.patches import Patch - - def tracked_power_from_fft(power1d, k): - """Get power at frequency k.""" - return float(power1d[k]) - - # Default steps - if steps is None: - final_step = len(param_history) - 1 - steps = [1, min(5, final_step), final_step] - steps = sorted(list(set(steps))) - - # Get dimensions - W0 = param_history[steps[0]]["W_out"].detach().cpu().numpy().T # (H, p) - H, D = W0.shape - assert p == D, f"p ({p}) must equal D ({D})." - - figures = [] - - # Create one figure per time step - for step in steps: - W = param_history[step]["W_out"].detach().cpu().numpy().T # (H, p) - - # Determine dominant frequency for each neuron - dom_idx = np.empty(H, dtype=int) - l2 = np.linalg.norm(W, axis=1) - dead_mask = l2 < dead_thresh_l2 - - for j in range(H): - neuron_weights = W[j] - power, _ = get_power_1d(neuron_weights) - tp = [tracked_power_from_fft(power, k) for k in tracked_freqs] - dom_idx[j] = int(np.argmax(tp)) - - # Assign colors - edge_colors = colors[dom_idx].copy() - edge_colors[dead_mask] = (0.6, 0.6, 0.6, 1.0) - - # Create grid of subplots - ncols = min(6, H) - nrows = int(np.ceil(H / ncols)) - - fig, axes = plt.subplots(nrows, ncols, figsize=(2.5 * ncols, 1.5 * nrows), squeeze=False) - - x = np.arange(p) - - for j in range(nrows * ncols): - row = j // ncols - col = j % ncols - ax = axes[row, col] - - if j < H: - # Plot neuron weights - ax.plot(x, W[j], color=edge_colors[j], lw=1.5) - ax.set_xlim(0, p - 1) - ax.set_ylim(W.min(), W.max()) - 
ax.set_xticks([]) - ax.set_yticks([]) - - # Colored border - for spine in ax.spines.values(): - spine.set_edgecolor(edge_colors[j]) - spine.set_linewidth(3) - else: - ax.axis("off") - - plt.tight_layout() - - if save_dir: - save_path = Path(save_dir) / f"wout_neurons_1d_epoch_{step:04d}.pdf" - fig.savefig(save_path, bbox_inches="tight", dpi=200) - print(f" ✓ Saved W_out 1D visualization for epoch {step}") - - if show: - plt.show() - else: - plt.close() - - figures.append(fig) - - # Create legend figure - fig_legend = plt.figure(figsize=(8, 2.0)) - ax_leg = fig_legend.add_subplot(111) - ax_leg.axis("off") - - handles = [ - Patch(facecolor="white", edgecolor=colors[i], linewidth=2.5, label=f"k={tracked_freqs[i]}") - for i in range(len(tracked_freqs)) - ] - handles.append( - Patch(facecolor="white", edgecolor=(0.6, 0.6, 0.6, 1.0), linewidth=2.5, label="dead") - ) - - ax_leg.legend( - handles=handles, - ncol=min(5, len(handles)), - frameon=True, - loc="center", - title="Dominant frequency", - fontsize=10, - ) - - if save_dir: - save_path = Path(save_dir) / "wout_legend_1d.pdf" - fig_legend.savefig(save_path, bbox_inches="tight", dpi=150) - print(" ✓ Saved legend") - - if show: - plt.show() - else: - plt.close() - - figures.append(fig_legend) - - return figures - - -def analyze_wout_frequency_dominance(state_dict, tracked_freqs, p1, p2): - """ - Analyze W_out to find dominant frequency for each neuron. 
- - Args: - state_dict: Model parameters (expects 'W_out' key) - tracked_freqs: List of (kx, ky) tuples - p1, p2: Template dimensions - - Returns: - dom_idx: Dominant frequency index for each neuron - phase: Phase at dominant frequency for each neuron - dom_power: Power at dominant frequency for each neuron - l2: L2 norm of each neuron's weights - """ - Wo = state_dict["W_out"].detach().cpu().numpy() # (p, H) - W = Wo.T # (H, p) - H, D = W.shape - assert D == p1 * p2 - - dom_idx = np.empty(H, dtype=int) - dom_pow = np.empty(H, dtype=float) - phase = np.empty(H, dtype=float) - l2 = np.linalg.norm(W, axis=1) - - for j in range(H): - m = W[j].reshape(p1, p2) - F = np.fft.fft2(m) - P = (F.conj() * F).real - # Power at tracked frequencies - tp = [_tracked_power_from_fft2(P, kx, ky, p1, p2) for (kx, ky) in tracked_freqs] - jj = int(np.argmax(tp)) - dom_idx[j] = jj - # Phase at representative bin - i0, j0 = tracked_freqs[jj][0] % p1, tracked_freqs[jj][1] % p2 - phase[j] = np.angle(F[i0, j0]) - dom_pow[j] = tp[jj] - - return dom_idx, phase, dom_pow, l2 - - -def plot_wmix_frequency_structure( - param_history, - tracked_freqs, - colors, - p1, - p2, - steps=None, - within_group_order="phase", - dead_l2_thresh=0.1, - save_path=None, - show=False, -): - """ - Visualize W_mix structure grouped by W_out frequency specialization. - - Creates heatmaps of W_mix reordered to show block structure based on - which Fourier mode each neuron is tuned to in W_out. 
- - Args: - param_history: List of parameter snapshots - tracked_freqs: List of (kx, ky) frequency tuples - colors: Array of colors for each frequency - p1, p2: Template dimensions - steps: List of epoch indices to plot (default: [1, 5, final]) - within_group_order: How to order neurons within each frequency group - ('phase', 'power', 'phase_power', 'none') - dead_l2_thresh: L2 threshold for dead neurons - save_path: Path to save figure - show: Whether to display plot - - Returns: - fig, axes - """ - from matplotlib.patches import Rectangle - - # Default steps - if steps is None: - final_step = len(param_history) - 1 - steps = [1, min(5, final_step), final_step] - steps = sorted(list(set(steps))) - - # Labels for frequencies - tracked_labels = [ - ("DC" if (kx, ky) == (0, 0) else f"({kx},{ky})") for (kx, ky) in tracked_freqs - ] - - # Analyze and reorder for each step - Wmix_perm_list = [] - group_info_list = [] - - for s in steps: - sd = param_history[s] - - # Analyze W_out - dom_idx, phase, dom_power, l2 = analyze_wout_frequency_dominance(sd, tracked_freqs, p1, p2) - - # Get W_mix (fallback to W_h for compatibility) - if "W_mix" in sd: - M = sd["W_mix"].detach().cpu().numpy() - elif "W_h" in sd: - M = sd["W_h"].detach().cpu().numpy() - else: - raise KeyError("Neither 'W_mix' nor 'W_h' found in state dict.") - - # Compute permutation - perm, group_keys, boundaries = _permutation_from_groups_with_dead( - dom_idx, phase, dom_power, l2, within=within_group_order, dead_l2_thresh=dead_l2_thresh - ) - - # Reorder - M_perm = M[perm][:, perm] - Wmix_perm_list.append(M_perm) - group_info_list.append((group_keys, boundaries)) - - # Shared color limits - vmax = max(np.max(np.abs(M)) for M in Wmix_perm_list) - vmin = -vmax if vmax > 0 else 0.0 - - # Create figure - n = len(steps) - fig, axes = plt.subplots(1, n, figsize=(3.8 * n, 3.8), constrained_layout=True) - if n == 1: - axes = [axes] - - cmap = "RdBu_r" - dead_gray = "0.35" - - im = None - for j, (s, M_perm) in 
enumerate(zip(steps, Wmix_perm_list)): - ax = axes[j] - im = ax.imshow( - M_perm, cmap=cmap, vmin=vmin, vmax=vmax, aspect="equal", interpolation="nearest" - ) - - ax.set_yticks([]) - ax.tick_params(axis="x", bottom=False) - - group_keys, boundaries = group_info_list[j] - - # Draw separators between groups - for b in boundaries[:-1]: - ax.axhline(b - 0.5, color="k", lw=0.9, alpha=0.65) - ax.axvline(b - 0.5, color="k", lw=0.9, alpha=0.65) - - # Draw colored boxes around frequency groups - starts = [0] + boundaries[:-1] - ends = [b - 1 for b in boundaries] - for kk, s0, e0 in zip(group_keys, starts, ends): - if kk == -1: # Skip dead neurons - continue - size = e0 - s0 + 1 - rect = Rectangle( - (s0 - 0.5, s0 - 0.5), - width=size, - height=size, - fill=False, - linewidth=2.0, - edgecolor=colors[kk], - alpha=0.95, - joinstyle="miter", - ) - ax.add_patch(rect) - - # Add labels at top - centers = [(s + e) / 2.0 for s, e in zip(starts, ends)] - sizes = [e - s + 1 for s, e in zip(starts, ends)] - - labels = [] - label_colors = [] - for kk, nn in zip(group_keys, sizes): - if kk == -1: - labels.append(f"DEAD\n(n={nn})") - label_colors.append(dead_gray) - else: - labels.append(f"{tracked_labels[kk]}\n(n={nn})") - label_colors.append(colors[kk]) - - ax.set_xticks(centers) - ax.set_xticklabels(labels, fontsize=11, ha="center") - ax.tick_params( - axis="x", bottom=False, top=True, labelbottom=False, labeltop=True, labelsize=11 - ) - for lbl, clr in zip(ax.get_xticklabels(), label_colors): - lbl.set_color(clr) - - ax.set_xlabel(f"Epoch {s}", fontsize=18, labelpad=8) - - # Shared colorbar - cbar = fig.colorbar(im, ax=axes, shrink=1.0, pad=0.012, aspect=18) - cbar.ax.tick_params(labelsize=11) - cbar.set_label("Weight value", fontsize=12) - - if save_path: - plt.savefig(save_path, bbox_inches="tight", dpi=200) - print(" ✓ Saved W_mix structure plot") - - if show: - plt.show() - else: - plt.close() - - return fig, axes diff --git a/src/viz.py b/src/viz.py new file mode 100644 index 
0000000..391fcd6 --- /dev/null +++ b/src/viz.py @@ -0,0 +1,954 @@ +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.ticker import MaxNLocator + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def style_axes(ax, numyticks=5, numxticks=5, labelsize=24): + # Y-axis ticks + ax.tick_params( + axis="y", + which="both", + bottom=True, + top=False, + labelbottom=True, + left=True, + right=False, + labelleft=True, + direction="out", + length=7, + width=1.5, + pad=8, + labelsize=labelsize, + ) + ax.yaxis.set_major_locator(MaxNLocator(nbins=numyticks)) + + # X-axis ticks + ax.tick_params( + axis="x", + which="both", + bottom=True, + top=False, + labelbottom=True, + left=True, + right=False, + labelleft=True, + direction="out", + length=7, + width=1.5, + pad=8, + labelsize=labelsize, + ) + ax.xaxis.set_major_locator(MaxNLocator(nbins=numxticks)) + + ax.xaxis.offsetText.set_fontsize(20) + ax.grid() + + # Customize spines + for spine in ["top", "right"]: + ax.spines[spine].set_visible(False) + for spine in ["left", "bottom"]: + ax.spines[spine].set_linewidth(3) + + +def _permutation_from_groups_with_dead( + dom_idx, phase, dom_power, l2, *, within="phase", dead_l2_thresh=1e-1 +): + """Create neuron permutation grouped by dominant frequency. 
+ + Args: + dom_idx: Dominant frequency index for each neuron + phase: Phase at dominant frequency for each neuron + dom_power: Power at dominant frequency for each neuron + l2: L2 norm of each neuron's weights + within: How to order within groups ('phase', 'power', 'phase_power', 'none') + dead_l2_thresh: L2 threshold below which neurons are "dead" + + Returns: + perm: Permutation indices + ordered_keys: Ordered list of group keys (-1 for dead) + boundaries: Cumulative indices where groups end + """ + dead_mask = l2 < float(dead_l2_thresh) + groups = {} + for i, f in enumerate(dom_idx): + key = -1 if dead_mask[i] else int(f) + groups.setdefault(key, []).append(i) + + freq_keys = sorted([k for k in groups.keys() if k >= 0]) + ordered_keys = freq_keys + ([-1] if -1 in groups else []) + + perm, boundaries = [], [] + for f in ordered_keys: + idxs = groups[f] + if f == -1: + idxs = sorted(idxs, key=lambda i: l2[i]) + else: + if within == "phase" and phase is not None: + idxs = sorted(idxs, key=lambda i: (phase[i] + 2 * np.pi) % (2 * np.pi)) + elif within == "power" and dom_power is not None: + idxs = sorted(idxs, key=lambda i: -dom_power[i]) + elif within == "phase_power": + idxs = sorted( + idxs, key=lambda i: ((phase[i] + 2 * np.pi) % (2 * np.pi), -dom_power[i]) + ) + perm.extend(idxs) + boundaries.append(len(perm)) + + return np.array(perm, dtype=int), ordered_keys, boundaries + + +def analyze_wout_frequency_dominance(state_dict, tracked_freqs, p1, p2): + """Analyze W_out to find dominant frequency for each neuron. 
+ + Args: + state_dict: Model parameters (expects 'W_out' key) + tracked_freqs: List of (kx, ky) tuples + p1, p2: Template dimensions + + Returns: + dom_idx: Dominant frequency index for each neuron + phase: Phase at dominant frequency for each neuron + dom_power: Power at dominant frequency for each neuron + l2: L2 norm of each neuron's weights + """ + import src.power as power + + Wo = state_dict["W_out"].detach().cpu().numpy() # (p, H) + W = Wo.T # (H, p) + H, D = W.shape + assert D == p1 * p2 + + dom_idx = np.empty(H, dtype=int) + dom_pow = np.empty(H, dtype=float) + phase = np.empty(H, dtype=float) + l2 = np.linalg.norm(W, axis=1) + + for j in range(H): + m = W[j].reshape(p1, p2) + F = np.fft.fft2(m) + P = (F.conj() * F).real + tp = [power._tracked_power_from_fft2(P, kx, ky, p1, p2) for (kx, ky) in tracked_freqs] + jj = int(np.argmax(tp)) + dom_idx[j] = jj + i0, j0 = tracked_freqs[jj][0] % p1, tracked_freqs[jj][1] % p2 + phase[j] = np.angle(F[i0, j0]) + dom_pow[j] = tp[jj] + + return dom_idx, phase, dom_pow, l2 + + +# --------------------------------------------------------------------------- +# Plotting functions +# --------------------------------------------------------------------------- + + +def plot_signal_2d( + signal_2d, + title="", + cmap="RdBu_r", + colorbar=True, +): + """Plot a 2D signal as a heatmap.""" + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + im = ax.imshow(signal_2d, cmap=cmap, aspect="equal", interpolation="nearest") + ax.set_title(title, fontsize=14) + ax.set_xlabel("y", fontsize=12) + ax.set_ylabel("x", fontsize=12) + if colorbar: + plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) + plt.tight_layout() + return fig, ax + + +def plot_train_loss_with_theory( + loss_history, template_2d, p1, p2, x_values=None, x_label="Step", save_path=None, show=True +): + """Plot training loss with theoretical power spectrum lines. 
+ + Args: + loss_history: List of loss values + template_2d: The 2D template array (p1, p2) + p1, p2: Dimensions + x_values: X-axis values (if None, uses indices 0, 1, 2, ...) + x_label: Label for x-axis (e.g., "Samples Seen", "Fraction of Space") + save_path: Optional path to save figure + show: Whether to display the plot + """ + import src.power as power + + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + + if x_values is None: + x_values = np.arange(len(loss_history)) + + ax.plot(x_values, loss_history, lw=4, color="#1f77b4", label="Training Loss") + + x_freq, y_freq, pwr = power.get_power_2d(template_2d) + pwr = pwr.flatten() + valid = pwr > 1e-20 + pwr = pwr[valid] + pwr = np.sort(pwr)[::-1] + + alpha_values = [np.sum(pwr[k:]) for k in range(len(pwr))] + coef = 1 / (p1 * p2) + for k, alpha in enumerate(alpha_values): + ax.axhline(y=coef * alpha, color="black", linestyle="--", linewidth=2, zorder=-2) + + ax.set_xlabel(x_label, fontsize=24) + ax.set_ylabel("Train Loss", fontsize=24) + + style_axes(ax) + ax.grid(False) + + plt.tight_layout() + + if save_path: + plt.savefig(save_path, bbox_inches="tight", dpi=150) + print(f" ✓ Saved loss plot to {save_path}") + + if show: + plt.show() + else: + plt.close() + + return fig, ax + + +def plot_predictions_2d( + model, + param_history, + X_data, + Y_data, + p1, + p2, + steps=None, + example_idx=None, + cmap="gray", + save_path=None, + show=False, +): + """Plot model predictions at different training steps vs ground truth (2D). 
+ + Args: + model: The trained model + param_history: List of parameter snapshots from training + X_data: Input tensor (N, k, p1*p2) + Y_data: Target tensor (N, p1*p2) + p1, p2: Dimensions + steps: List of epoch indices to plot + example_idx: Index of example to visualize + cmap: Colormap to use + save_path: Path to save figure + show: Whether to display the plot + """ + import torch + + if steps is None: + final_step = len(param_history) - 1 + steps = [1, min(5, final_step), min(10, final_step), final_step] + steps = sorted(list(set(steps))) + + if example_idx is None: + example_idx = int(np.random.randint(len(Y_data))) + + device = next(model.parameters()).device + model.to(device).eval() + + if Y_data.dim() == 3: + Y_data = Y_data[:, -1, :] + with torch.no_grad(): + truth_2d = Y_data[example_idx].reshape(p1, p2).cpu().numpy() + + preds = [] + for step in steps: + model.load_state_dict(param_history[step], strict=True) + with torch.no_grad(): + x = X_data[example_idx : example_idx + 1].to(device) + pred_2d = model(x) + if pred_2d.dim() == 3: + pred_2d = pred_2d[:, -1, :] + pred_2d = pred_2d.reshape(p1, p2).detach().cpu().numpy() + preds.append(pred_2d) + + vmin = np.min(truth_2d) + vmax = np.max(truth_2d) + + fig, axes = plt.subplots(2, len(steps), figsize=(3.5 * len(steps), 6), layout="constrained") + if len(steps) == 1: + axes = axes.reshape(2, 1) + + for col, (step, pred_2d) in enumerate(zip(steps, preds)): + im = axes[0, col].imshow(pred_2d, vmin=vmin, vmax=vmax, cmap=cmap, origin="upper") + axes[0, col].set_title(f"Epoch {step}", fontsize=12) + axes[0, col].set_xticks([]) + axes[0, col].set_yticks([]) + + axes[1, col].imshow(truth_2d, vmin=vmin, vmax=vmax, cmap=cmap, origin="upper") + axes[1, col].set_xticks([]) + axes[1, col].set_yticks([]) + + axes[0, 0].set_ylabel("Prediction", fontsize=14) + axes[1, 0].set_ylabel("Target", fontsize=14) + + fig.colorbar(im, ax=axes, location="right", shrink=0.9, pad=0.02).set_label( + "Value", fontsize=12 + ) + + if 
save_path: + plt.savefig(save_path, bbox_inches="tight", dpi=150) + print(f" ✓ Saved predictions plot to {save_path}") + + if show: + plt.show() + else: + plt.close() + + return fig, axes + + +def plot_predictions_1d( + model, + param_history, + X_data, + Y_data, + p, + steps=None, + example_idx=None, + save_path=None, + show=False, +): + """Plot model predictions at different training steps vs ground truth (1D). + + Args: + model: The trained model + param_history: List of parameter snapshots from training + X_data: Input tensor (N, k, p) + Y_data: Target tensor (N, p) + p: Dimension + steps: List of epoch indices to plot + example_idx: Index of example to visualize + save_path: Path to save figure + show: Whether to display the plot + """ + import torch + + if steps is None: + final_step = len(param_history) - 1 + steps = [1, min(5, final_step), min(10, final_step), final_step] + steps = sorted(list(set(steps))) + + if example_idx is None: + example_idx = int(np.random.randint(len(Y_data))) + + device = next(model.parameters()).device + model.to(device).eval() + + if Y_data.dim() == 3: + Y_data = Y_data[:, -1, :] + with torch.no_grad(): + truth_1d = Y_data[example_idx].cpu().numpy() + + preds = [] + for step in steps: + model.load_state_dict(param_history[step], strict=True) + with torch.no_grad(): + x = X_data[example_idx : example_idx + 1].to(device) + pred = model(x) + if pred.dim() == 3: + pred = pred[:, -1, :] + pred_1d = pred.squeeze().detach().cpu().numpy() + preds.append(pred_1d) + + fig, axes = plt.subplots(2, len(steps), figsize=(3.5 * len(steps), 4), layout="constrained") + if len(steps) == 1: + axes = axes.reshape(2, 1) + + x = np.arange(p) + + for col, (step, pred_1d) in enumerate(zip(steps, preds)): + axes[0, col].plot(x, pred_1d, "b-", lw=2) + axes[0, col].set_title(f"Epoch {step}", fontsize=12) + axes[0, col].set_ylim( + truth_1d.min() - 0.1 * np.abs(truth_1d.min()), + truth_1d.max() + 0.1 * np.abs(truth_1d.max()), + ) + axes[0, 
col].set_xticks([]) + axes[0, col].grid(True, alpha=0.3) + + axes[1, col].plot(x, truth_1d, "k-", lw=2) + axes[1, col].set_ylim( + truth_1d.min() - 0.1 * np.abs(truth_1d.min()), + truth_1d.max() + 0.1 * np.abs(truth_1d.max()), + ) + axes[1, col].set_xticks([]) + axes[1, col].grid(True, alpha=0.3) + + axes[0, 0].set_ylabel("Prediction", fontsize=14) + axes[1, 0].set_ylabel("Target", fontsize=14) + + if save_path: + plt.savefig(save_path, bbox_inches="tight", dpi=150) + print(f" ✓ Saved predictions plot to {save_path}") + + if show: + plt.show() + else: + plt.close() + + return fig, axes + + +def plot_power_1d( + model, + param_history, + X_data, + Y_data, + template_1d, + p, + loss_history, + param_save_indices=None, + num_freqs_to_track=10, + checkpoint_indices=None, + num_samples=100, + save_path=None, + show=False, +): + """Plot training loss with power spectrum analysis of predictions over time (1D). + + Creates a two-panel plot: + - Top: Training loss with colored bands for theory lines + - Bottom: Power in tracked frequencies over time + + Args: + model: The trained model + param_history: List of parameter snapshots + X_data: Input tensor (N, k, p) + Y_data: Target tensor (N, p) + template_1d: The 1D template array (p,) + p: Dimension of the template + loss_history: List of loss values + param_save_indices: List of step/epoch numbers where params were saved + num_freqs_to_track: Number of top frequencies to track + checkpoint_indices: (deprecated/unused) + num_samples: Number of samples to average for power computation + save_path: Path to save figure + show: Whether to display the plot + """ + import torch + from matplotlib.ticker import FormatStrFormatter + from tqdm import tqdm + + import src.power as power + + device = next(model.parameters()).device + + tracked_freqs = power.topk_template_freqs_1d(template_1d, K=num_freqs_to_track) + template_power, _ = power.get_power_1d(template_1d) + target_powers = {k: template_power[k] for k in tracked_freqs} + + T = 
len(param_history) + steps_analysis = list(range(len(param_history))) + + if param_save_indices is not None: + actual_steps = param_save_indices + else: + actual_steps = list(range(len(param_history))) + + powers_over_time = {freq: [] for freq in tracked_freqs} + + print(f" Analyzing {len(steps_analysis)} checkpoints for power spectrum (1D)...") + + with torch.no_grad(): + for step in tqdm(steps_analysis, desc=" Computing power spectra", leave=False): + model.load_state_dict(param_history[step], strict=True) + model.eval() + + outputs_flat = model(X_data[:num_samples].to(device)).detach().cpu().numpy() + + powers_batch = [] + for i in range(outputs_flat.shape[0]): + if outputs_flat.ndim == 3: + out_1d = outputs_flat[i, -1, :] + else: + out_1d = outputs_flat[i] + power_i, _ = power.get_power_1d(out_1d) + powers_batch.append(power_i) + avg_power = np.mean(powers_batch, axis=0) + + for k in tracked_freqs: + powers_over_time[k].append(avg_power[k]) + + for freq in tracked_freqs: + powers_over_time[freq] = np.array(powers_over_time[freq]) + + if param_save_indices is None: + loss_epochs = np.arange(len(param_history)) + loss_history_subset = loss_history + else: + loss_epochs = np.array(param_save_indices) + loss_history_subset = [loss_history[i] for i in param_save_indices] + + colors = plt.cm.tab10(np.linspace(0, 1, len(tracked_freqs))) + + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10), sharex=True) + fig.subplots_adjust(left=0.12, right=0.98, top=0.96, bottom=0.10, hspace=0.12) + + ax1.plot(loss_epochs, loss_history_subset, lw=4, color="#1f77b4", label="Training Loss") + + pwr, _ = power.get_power_1d(template_1d) + power_sorted = np.sort(pwr[pwr > 1e-20])[::-1] + + alpha_values = np.array([np.sum(power_sorted[k:]) for k in range(len(power_sorted))]) + coef = 1.0 / p + y_levels = coef * alpha_values + + n_bands = min(len(tracked_freqs), len(y_levels) - 1) + for i in range(n_bands): + y_top = y_levels[i] + y_bot = y_levels[i + 1] + ax1.axhspan(y_bot, y_top, 
facecolor=colors[i], alpha=0.15, zorder=-3) + + for y in y_levels[: n_bands + 1]: + ax1.axhline(y=y, color="black", linestyle="--", linewidth=2, zorder=-2) + + ax1.set_ylabel("Theory Loss Levels", fontsize=20) + ax1.set_ylim(y_levels[n_bands], y_levels[0] * 1.1) + style_axes(ax1) + ax1.grid(False) + ax1.tick_params(labelbottom=False) + + for i, k in enumerate(tracked_freqs): + ax2.plot(actual_steps, powers_over_time[k], color=colors[i], lw=3, label=f"k={k}") + ax2.axhline( + target_powers[k], + color=colors[i], + linestyle="dotted", + linewidth=2, + alpha=0.5, + ) + + ax2.set_xlabel("Steps", fontsize=20) + ax2.set_ylabel("Power in Prediction", fontsize=20) + ax2.grid(True, alpha=0.3) + ax2.legend(fontsize=10, loc="best", ncol=2) + style_axes(ax2) + ax2.yaxis.set_major_formatter(FormatStrFormatter("%.1f")) + + if save_path: + plt.savefig(save_path, bbox_inches="tight", dpi=150) + print(f" ✓ Saved power spectrum plot to {save_path}") + + if show: + plt.show() + else: + plt.close() + + return fig, (ax1, ax2), powers_over_time, tracked_freqs + + +def plot_wmix_structure( + param_history, + tracked_freqs, + colors, + p1, + p2, + steps=None, + within_group_order="phase", + dead_l2_thresh=0.1, + save_path=None, + show=False, +): + """Visualize W_mix structure grouped by W_out frequency specialization. 
+ + Args: + param_history: List of parameter snapshots + tracked_freqs: List of (kx, ky) frequency tuples + colors: Array of colors for each frequency + p1, p2: Template dimensions + steps: List of epoch indices to plot + within_group_order: How to order neurons within each frequency group + dead_l2_thresh: L2 threshold for dead neurons + save_path: Path to save figure + show: Whether to display plot + """ + from matplotlib.patches import Rectangle + + if steps is None: + final_step = len(param_history) - 1 + steps = [1, min(5, final_step), final_step] + steps = sorted(list(set(steps))) + + tracked_labels = [ + ("DC" if (kx, ky) == (0, 0) else f"({kx},{ky})") for (kx, ky) in tracked_freqs + ] + + Wmix_perm_list = [] + group_info_list = [] + + for s in steps: + sd = param_history[s] + dom_idx, phase, dom_power, l2 = analyze_wout_frequency_dominance(sd, tracked_freqs, p1, p2) + + if "W_mix" in sd: + M = sd["W_mix"].detach().cpu().numpy() + elif "W_h" in sd: + M = sd["W_h"].detach().cpu().numpy() + else: + raise KeyError("Neither 'W_mix' nor 'W_h' found in state dict.") + + perm, group_keys, boundaries = _permutation_from_groups_with_dead( + dom_idx, phase, dom_power, l2, within=within_group_order, dead_l2_thresh=dead_l2_thresh + ) + + M_perm = M[perm][:, perm] + Wmix_perm_list.append(M_perm) + group_info_list.append((group_keys, boundaries)) + + vmax = max(np.max(np.abs(M)) for M in Wmix_perm_list) + vmin = -vmax if vmax > 0 else 0.0 + + n = len(steps) + fig, axes = plt.subplots(1, n, figsize=(3.8 * n, 3.8), constrained_layout=True) + if n == 1: + axes = [axes] + + cmap = "RdBu_r" + dead_gray = "0.35" + + im = None + for j, (s, M_perm) in enumerate(zip(steps, Wmix_perm_list)): + ax = axes[j] + im = ax.imshow( + M_perm, cmap=cmap, vmin=vmin, vmax=vmax, aspect="equal", interpolation="nearest" + ) + + ax.set_yticks([]) + ax.tick_params(axis="x", bottom=False) + + group_keys, boundaries = group_info_list[j] + + for b in boundaries[:-1]: + ax.axhline(b - 0.5, color="k", 
lw=0.9, alpha=0.65) + ax.axvline(b - 0.5, color="k", lw=0.9, alpha=0.65) + + starts = [0] + boundaries[:-1] + ends = [b - 1 for b in boundaries] + for kk, s0, e0 in zip(group_keys, starts, ends): + if kk == -1: + continue + size = e0 - s0 + 1 + rect = Rectangle( + (s0 - 0.5, s0 - 0.5), + width=size, + height=size, + fill=False, + linewidth=2.0, + edgecolor=colors[kk], + alpha=0.95, + joinstyle="miter", + ) + ax.add_patch(rect) + + centers = [(s + e) / 2.0 for s, e in zip(starts, ends)] + sizes = [e - s + 1 for s, e in zip(starts, ends)] + + labels = [] + label_colors = [] + for kk, nn in zip(group_keys, sizes): + if kk == -1: + labels.append(f"DEAD\n(n={nn})") + label_colors.append(dead_gray) + else: + labels.append(f"{tracked_labels[kk]}\n(n={nn})") + label_colors.append(colors[kk]) + + ax.set_xticks(centers) + ax.set_xticklabels(labels, fontsize=11, ha="center") + ax.tick_params( + axis="x", bottom=False, top=True, labelbottom=False, labeltop=True, labelsize=11 + ) + for lbl, clr in zip(ax.get_xticklabels(), label_colors): + lbl.set_color(clr) + + ax.set_xlabel(f"Epoch {s}", fontsize=18, labelpad=8) + + cbar = fig.colorbar(im, ax=axes, shrink=1.0, pad=0.012, aspect=18) + cbar.ax.tick_params(labelsize=11) + cbar.set_label("Weight value", fontsize=12) + + if save_path: + plt.savefig(save_path, bbox_inches="tight", dpi=200) + print(" ✓ Saved W_mix structure plot") + + if show: + plt.show() + else: + plt.close() + + return fig, axes + + +def plot_predictions_group( + model, + param_hist, + X_eval, + Y_eval, + group_order: int, + checkpoint_indices: list, + save_path: str = None, + num_samples: int = 5, + group_label: str = "Group", +): + """Plot model predictions vs targets at different training checkpoints. 
+ + Args: + model: Trained model + param_hist: List of parameter snapshots + X_eval: Input evaluation tensor (N, k, group_order) + Y_eval: Target evaluation tensor (N, group_order) + group_order: Order of the group + checkpoint_indices: Indices into param_hist to visualize + save_path: Path to save the plot + num_samples: Number of samples to show + group_label: Human-readable label for the group (used in plot title) + """ + import torch + + n_checkpoints = len(checkpoint_indices) + + fig, axes = plt.subplots( + num_samples, n_checkpoints, figsize=(4 * n_checkpoints, 3 * num_samples) + ) + if num_samples == 1: + axes = axes.reshape(1, -1) + if n_checkpoints == 1: + axes = axes.reshape(-1, 1) + + sample_indices = np.random.choice( + len(X_eval), size=min(num_samples, len(X_eval)), replace=False + ) + + for col, ckpt_idx in enumerate(checkpoint_indices): + model.load_state_dict(param_hist[ckpt_idx]) + model.eval() + + with torch.no_grad(): + outputs = model(X_eval[sample_indices]) + outputs_np = outputs.cpu().numpy() + targets_np = Y_eval[sample_indices].cpu().numpy() + + for row, (output, target) in enumerate(zip(outputs_np, targets_np)): + ax = axes[row, col] + x_axis = np.arange(group_order) + + ax.bar(x_axis - 0.15, target, width=0.3, label="Target", alpha=0.7, color="#2ecc71") + ax.bar(x_axis + 0.15, output, width=0.3, label="Output", alpha=0.7, color="#e74c3c") + + if row == 0: + ax.set_title(f"Checkpoint {ckpt_idx}") + if col == 0: + ax.set_ylabel(f"Sample {sample_indices[row]}") + if row == num_samples - 1: + ax.set_xlabel("Group element") + if row == 0 and col == n_checkpoints - 1: + ax.legend(loc="upper right", fontsize=8) + + ax.set_xticks(x_axis) + ax.grid(True, alpha=0.3) + + plt.suptitle(f"{group_label} Predictions vs Targets Over Training", fontsize=14) + plt.tight_layout() + + if save_path: + plt.savefig(save_path, bbox_inches="tight", dpi=150) + plt.close() + + +def plot_power_group( + model, + param_hist, + param_save_indices, + X_eval, + template: 
np.ndarray, + group, + k: int, + optimizer: str, + init_scale: float, + save_path: str = None, + group_label: str = "Group", +): + """Plot power spectrum of model outputs vs template over training. + + Uses GroupPower from src/power.py for template power and model_power_over_time + for model output power over training checkpoints. + + Args: + model: Trained model + param_hist: List of parameter snapshots + param_save_indices: List mapping param_hist index to epoch number + X_eval: Input evaluation tensor + template: Template array (group_order,) + group: escnn group object + k: Sequence length + optimizer: Optimizer name + init_scale: Initialization scale + save_path: Path to save the plot + group_label: Human-readable label for the group + """ + import src.power as power + + group_name = "group" + irreps = group.irreps() + n_irreps = len(irreps) + + template_power_obj = power.GroupPower(template, group=group) + template_power = template_power_obj.power + + print(f" Template power spectrum: {template_power}") + print(" (These are dim^2 * diag_value^2 / |G| for each irrep)") + + model_powers, steps = power.model_power_over_time( + group_name, model, param_hist, X_eval, group=group + ) + epoch_numbers = [param_save_indices[min(s, len(param_save_indices) - 1)] for s in steps] + + fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + + top_k = min(5, n_irreps) + top_irrep_indices = np.argsort(template_power)[::-1][:top_k] + + colors_line = plt.cm.tab10(np.linspace(0, 1, top_k)) + + valid_mask = np.array(epoch_numbers) > 0 + valid_epochs = np.array(epoch_numbers)[valid_mask] + valid_model_powers = model_powers[valid_mask, :] + + # Plot 1: Linear scales + ax = axes[0] + for i, irrep_idx in enumerate(top_irrep_indices): + power_values = model_powers[:, irrep_idx] + ax.plot( + epoch_numbers, + power_values, + "-", + lw=2, + color=colors_line[i], + label=f"Irrep {irrep_idx} (dim={irreps[irrep_idx].size})", + ) + ax.axhline(template_power[irrep_idx], linestyle="--", alpha=0.5, 
color=colors_line[i]) + ax.set_xlabel("Epoch") + ax.set_ylabel("Power") + ax.set_title("Linear Scales", fontsize=12) + ax.legend(loc="upper left", fontsize=7) + ax.grid(True, alpha=0.3) + + # Plot 2: Log x-axis only + ax = axes[1] + for i, irrep_idx in enumerate(top_irrep_indices): + power_values = valid_model_powers[:, irrep_idx] + ax.plot( + valid_epochs, + power_values, + "-", + lw=2, + color=colors_line[i], + label=f"Irrep {irrep_idx} (dim={irreps[irrep_idx].size})", + ) + ax.axhline(template_power[irrep_idx], linestyle="--", alpha=0.5, color=colors_line[i]) + ax.set_xscale("log") + ax.set_xlabel("Epoch (log scale)") + ax.set_ylabel("Power") + ax.set_title("Log X-axis", fontsize=12) + ax.legend(loc="upper left", fontsize=7) + ax.grid(True, alpha=0.3) + + # Plot 3: Log-log scales + ax = axes[2] + for i, irrep_idx in enumerate(top_irrep_indices): + power_values = valid_model_powers[:, irrep_idx] + power_mask = power_values > 0 + if np.any(power_mask): + ax.plot( + valid_epochs[power_mask], + power_values[power_mask], + "-", + lw=2, + color=colors_line[i], + label=f"Irrep {irrep_idx} (dim={irreps[irrep_idx].size})", + ) + if template_power[irrep_idx] > 0: + ax.axhline(template_power[irrep_idx], linestyle="--", alpha=0.5, color=colors_line[i]) + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlabel("Epoch (log scale)") + ax.set_ylabel("Power (log scale)") + ax.set_title("Log-Log Scales", fontsize=12) + ax.legend(loc="upper left", fontsize=7) + ax.grid(True, alpha=0.3) + + fig.suptitle( + f"{group_label} Power Evolution Over Training (k={k}, {optimizer}, init={init_scale:.0e})", + fontsize=14, + fontweight="bold", + ) + + plt.tight_layout() + + if save_path: + plt.savefig(save_path, bbox_inches="tight", dpi=150) + plt.close() + + +def plot_irreps(group, show=False): + """Plot the irreducible representations (irreps) of the group. + + Parameters + ---------- + group : class instance + The group for which the irreps are being plotted. 
+ show : bool, optional + Whether to display the plot immediately. + """ + FONT_SIZES = {"title": 30, "axes_label": 30, "tick_label": 30, "legend": 15} + + irreps = group.irreps() + group_elements = group.elements + + num_irreps = len(irreps) + fig, axs = plt.subplots(1, num_irreps, figsize=(3 * num_irreps, 4), squeeze=False) + axs = axs[0] + + for i, irrep in enumerate(irreps): + matrices = [irrep(g) for g in group_elements] + matrices = np.array(matrices) + + if matrices.ndim == 1 or (matrices.ndim == 2 and matrices.shape[1] == 1): + axs[i].plot(range(len(group_elements)), matrices.real, marker="o", label="Re") + if np.any(np.abs(matrices.imag) > 1e-10): + axs[i].plot(range(len(group_elements)), matrices.imag, marker="x", label="Im") + axs[i].set_title(f"Irrep {i}: {str(irrep)} (dim=1)") + axs[i].set_xlabel("Group element idx") + axs[i].set_ylabel("Irrep value") + axs[i].legend() + else: + d = matrices.shape[1] + num_group_elements = len(group_elements) + num_irrep_entries = d * d + irrep_matrix_entries = matrices.real.reshape(num_group_elements, num_irrep_entries) + im = axs[i].imshow(irrep_matrix_entries, aspect="auto", cmap="viridis") + axs[i].set_title(f"Irrep {i}: {str(irrep)} (size={d}x{d})") + axs[i].set_xlabel("Flattened Irreps") + axs[i].set_ylabel("Irrep(g)") + plt.colorbar(im, ax=axs[i]) + fig.suptitle( + "Irreducible Representations (matrix values for all group elements)", + fontsize=FONT_SIZES["title"], + ) + plt.tight_layout() + if show: + plt.show() + return fig diff --git a/test/test_bal_datasets.py b/test/test_bal_datasets.py deleted file mode 100644 index 14b67f7..0000000 --- a/test/test_bal_datasets.py +++ /dev/null @@ -1,149 +0,0 @@ -"""Tests for src.datasets module.""" - -import numpy as np -import pytest -import torch - -from src.datasets import ( - cn_dataset, - cnxcn_dataset, - group_dataset, - move_dataset_to_device_and_flatten, -) - - -class TestCnDataset: - """Tests for cn_dataset function.""" - - def test_output_shape(self): - """Test 
that output shapes are correct.""" - group_size = 7 - template = np.random.randn(group_size) - - X, Y = cn_dataset(template) - - n_samples = group_size**2 - assert X.shape == (n_samples, 2, group_size), f"X shape mismatch: {X.shape}" - assert Y.shape == (n_samples, group_size), f"Y shape mismatch: {Y.shape}" - - def test_modular_addition_property(self): - """Test that Y is the rolled template by (a+b) mod p.""" - group_size = 5 - template = np.arange(group_size).astype(float) # [0, 1, 2, 3, 4] - - X, Y = cn_dataset(template) - - # Check a specific case: a=1, b=2 -> q=(1+2)%5=3 - # Index = a * group_size + b = 1 * 5 + 2 = 7 - idx = 1 * group_size + 2 - expected_y = np.roll(template, 3) # rolled by 3 - np.testing.assert_allclose(Y[idx], expected_y) - - def test_covers_all_pairs(self): - """Test that all pairs (a, b) are covered.""" - group_size = 4 - template = np.random.randn(group_size) - - X, Y = cn_dataset(template) - - # Should have exactly group_size^2 samples - assert X.shape[0] == group_size**2 - - -class TestCnxcnDataset: - """Tests for cnxcn_dataset function.""" - - def test_output_shape(self): - """Test that output shapes are correct.""" - image_length = 4 - template = np.random.randn(image_length * image_length) - - X, Y = cnxcn_dataset(template) - - n_samples = image_length**4 - n_features = image_length * image_length - assert X.shape == (n_samples, 2, n_features), f"X shape mismatch: {X.shape}" - assert Y.shape == (n_samples, n_features), f"Y shape mismatch: {Y.shape}" - - def test_covers_all_combinations(self): - """Test that all combinations are covered.""" - image_length = 3 - template = np.random.randn(image_length * image_length) - - X, Y = cnxcn_dataset(template) - - expected_n = image_length**4 - assert X.shape[0] == expected_n - - -class TestGroupDataset: - """Tests for group_dataset function.""" - - @pytest.fixture - def dihedral_group(self): - """Create a DihedralGroup for testing.""" - from escnn.group import DihedralGroup - - return 
DihedralGroup(N=3) # D3 - - def test_output_shape(self, dihedral_group): - """Test that output shapes are correct for D3.""" - group_order = dihedral_group.order() # 6 for D3 - template = np.random.randn(group_order) - - X, Y = group_dataset(dihedral_group, template) - - n_samples = group_order**2 - assert X.shape == (n_samples, 2, group_order), f"X shape mismatch: {X.shape}" - assert Y.shape == (n_samples, group_order), f"Y shape mismatch: {Y.shape}" - - def test_template_length_mismatch_error(self, dihedral_group): - """Test that mismatched template length raises error.""" - wrong_size = dihedral_group.order() + 1 - template = np.random.randn(wrong_size) - - with pytest.raises(AssertionError): - group_dataset(dihedral_group, template) - - -class TestMoveDatasetToDeviceAndFlatten: - """Tests for move_dataset_to_device_and_flatten function.""" - - def test_output_shape_and_type(self): - """Test that output shapes and types are correct.""" - group_size = 5 - n_samples = 10 - - X = np.random.randn(n_samples, 2, group_size) - Y = np.random.randn(n_samples, group_size) - - X_tensor, Y_tensor, device = move_dataset_to_device_and_flatten(X, Y, device="cpu") - - assert isinstance(X_tensor, torch.Tensor) - assert isinstance(Y_tensor, torch.Tensor) - assert X_tensor.shape == (n_samples, 2 * group_size) - assert Y_tensor.shape == (n_samples, group_size) - - def test_flattening(self): - """Test that X is correctly flattened.""" - group_size = 4 - n_samples = 5 - - X = np.arange(n_samples * 2 * group_size).reshape(n_samples, 2, group_size).astype(float) - Y = np.random.randn(n_samples, group_size) - - X_tensor, Y_tensor, device = move_dataset_to_device_and_flatten(X, Y, device="cpu") - - # Check first sample - expected_flat = np.concatenate([X[0, 0, :], X[0, 1, :]]) - np.testing.assert_allclose(X_tensor[0].numpy(), expected_flat) - - def test_device_cpu(self): - """Test explicit CPU device.""" - X = np.random.randn(5, 2, 4) - Y = np.random.randn(5, 4) - - X_tensor, Y_tensor, 
device = move_dataset_to_device_and_flatten(X, Y, device="cpu") - - assert X_tensor.device.type == "cpu" - assert Y_tensor.device.type == "cpu" diff --git a/test/test_bal_group_fourier_transform.py b/test/test_bal_group_fourier_transform.py deleted file mode 100644 index 306fa89..0000000 --- a/test/test_bal_group_fourier_transform.py +++ /dev/null @@ -1,49 +0,0 @@ -import numpy as np -from escnn.group import Octahedral - -from src.group_fourier_transform import ( - compute_group_fourier_transform, - compute_group_inverse_fourier_transform, -) -from src.templates import fixed_group_template - - -def test_fourier_inverse_is_identity(): - group = Octahedral() - seed = 42 - - # Generate template with nontrivial spectrum - template = fixed_group_template(group, fourier_coef_diag_values=[100.0, 20.0, 0.0, 0.0, 0.0]) - - # Forward Fourier transform - fourier_transform = compute_group_fourier_transform(group, template) - - # Inverse Fourier transform - reconstructed = compute_group_inverse_fourier_transform(group, fourier_transform) - - # Perform Fourier transform of the reconstructed template - fourier_transform_reconstructed = compute_group_fourier_transform(group, reconstructed) - - # Check that the original and reconstructed template are close - assert np.allclose(template, reconstructed, atol=1e-10), ( - f"Inversion failed! 
max diff: {np.max(np.abs(template - reconstructed))}" - ) - print(f"diff: {(np.abs(template - reconstructed))}") - - # Check that the Fourier transform of the reconstructed template is close to the original Fourier transform - print(f"fourier_transform: {[ft.shape for ft in fourier_transform]}") - print( - f"fourier_transform_reconstructed: {[ft.shape for ft in fourier_transform_reconstructed]}" - ) - assert len(fourier_transform) == len(fourier_transform_reconstructed), ( - f"Length mismatch: {len(fourier_transform)} vs {len(fourier_transform_reconstructed)}" - ) - for i, (ft, ft_rec) in enumerate(zip(fourier_transform, fourier_transform_reconstructed)): - assert np.allclose(ft, ft_rec, atol=1e-10), ( - f"Fourier transform failed at index {i}! max diff: {np.max(np.abs(ft - ft_rec))}" - ) - print(f"diff at index {i}: {np.max(np.abs(ft - ft_rec))}") - - -if __name__ == "__main__": - test_fourier_inverse_is_identity() diff --git a/test/test_bal_main.py b/test/test_bal_main.py deleted file mode 100644 index 6bbd811..0000000 --- a/test/test_bal_main.py +++ /dev/null @@ -1,204 +0,0 @@ -""" -Tests for group_agf/binary_action_learning/main.py - -This module tests that the main() entry point runs successfully with minimal -configuration. Tests are only run when MAIN_TEST_MODE=1 environment variable -is set to avoid long-running tests in regular CI. 
- -Expected runtime: < 1 minute with MAIN_TEST_MODE=1 - -Usage: - MAIN_TEST_MODE=1 pytest test/test_bal_main.py -v -""" - -import os -import sys -import tempfile -from pathlib import Path -from unittest.mock import MagicMock, patch - -import pytest - -# Check for MAIN_TEST_MODE -MAIN_TEST_MODE = os.environ.get("MAIN_TEST_MODE", "0") == "1" - -# Add test directory to path and register test_default_config BEFORE any imports -# that might trigger loading of group_agf.binary_action_learning.main -_test_dir = Path(__file__).parent -if str(_test_dir) not in sys.path: - sys.path.insert(0, str(_test_dir)) - -# Import and register test_default_config as 'default_config' in sys.modules -# This must happen before any import of group_agf.binary_action_learning.main -import test_default_config # noqa: E402 - -sys.modules["default_config"] = test_default_config - - -@pytest.fixture -def temp_save_dir(): - """Create a temporary directory for saving model outputs.""" - with tempfile.TemporaryDirectory() as tmpdir: - yield tmpdir - - -@pytest.fixture -def mock_wandb(): - """Mock wandb to avoid actual logging.""" - mock_run = MagicMock() - mock_run.id = "test_run_123" - mock_run.name = "test_run" - - mock_config = MagicMock() - - with ( - patch("wandb.init") as mock_init, - patch("wandb.config", mock_config), - patch("wandb.run", mock_run), - patch("wandb.log") as mock_log, - patch("wandb.finish") as mock_finish, - patch("wandb.Image") as mock_image, - ): - mock_init.return_value = mock_run - mock_image.return_value = MagicMock() - yield { - "init": mock_init, - "config": mock_config, - "run": mock_run, - "log": mock_log, - "finish": mock_finish, - "image": mock_image, - } - - -@pytest.fixture -def mock_plots(): - """Mock plot functions to skip visualization.""" - with ( - patch("group_agf.binary_action_learning.plot.plot_loss_curve") as mock_loss, - patch("group_agf.binary_action_learning.plot.plot_training_power_over_time") as mock_power, - 
patch("group_agf.binary_action_learning.plot.plot_neuron_weights") as mock_weights, - patch("group_agf.binary_action_learning.plot.plot_model_outputs") as mock_outputs, - ): - # Return mock figure objects - mock_fig = MagicMock() - mock_loss.return_value = mock_fig - mock_power.return_value = mock_fig - mock_weights.return_value = mock_fig - mock_outputs.return_value = mock_fig - yield { - "loss": mock_loss, - "power": mock_power, - "weights": mock_weights, - "outputs": mock_outputs, - } - - -@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1") -def test_main_run_cn_group(temp_save_dir, mock_wandb, mock_plots): - """ - Test main_run() with a minimal cyclic group (C5) configuration. - - This tests the core training pipeline without the full main() iteration. - """ - # Update test_default_config to use temp directory - test_default_config.model_save_dir = temp_save_dir + "/" - - from group_agf.binary_action_learning.main import main_run - - # Create minimal config for C5 group - config = { - "group_name": "cn", - "group_size": 5, - "group_n": 5, - "epochs": 2, - "batch_size": 32, - "hidden_factor": 2, - "init_scale": 1e-2, - "lr": 0.01, - "mom": 0.9, - "optimizer_name": "SGD", - "seed": 42, - "verbose_interval": 1, - "model_save_dir": temp_save_dir + "/", - "dataset_fraction": 1.0, - "powers": [0, 10, 5], - "fourier_coef_diag_values": [0, 10, 5], - "power_logscale": False, - "resume_from_checkpoint": False, - "checkpoint_interval": 1000, - "checkpoint_path": None, - "template_type": "one_hot", - "run_start_time": "test_run", - } - - # Run the training - main_run(config) - - # Verify wandb was called (at least once - may be called multiple times in some flows) - assert mock_wandb["init"].call_count >= 1 - assert mock_wandb["finish"].call_count >= 1 - - -@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1") -def test_main_entry_point(temp_save_dir, mock_wandb, mock_plots): - """ - Test the full main() entry point 
with mocked default_config. - - This tests what happens when you run `python main.py`. - """ - # Update model_save_dir in test_default_config to use temp directory - test_default_config.model_save_dir = temp_save_dir + "/" - - # Import the main module (default_config is already mocked at module level) - from group_agf.binary_action_learning.main import main - - # Run main() - main() - - # Verify wandb was called (at least once for the single config combination) - assert mock_wandb["init"].call_count >= 1 - assert mock_wandb["finish"].call_count >= 1 - - -@pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1") -def test_main_run_with_adam_optimizer(temp_save_dir, mock_wandb, mock_plots): - """Test main_run() with Adam optimizer.""" - # Update test_default_config to use temp directory - test_default_config.model_save_dir = temp_save_dir + "/" - - from group_agf.binary_action_learning.main import main_run - - config = { - "group_name": "cn", - "group_size": 5, - "group_n": 5, - "epochs": 2, - "batch_size": 32, - "hidden_factor": 2, - "init_scale": 1e-2, - "lr": 0.001, - "mom": 0.9, - "optimizer_name": "Adam", - "seed": 42, - "verbose_interval": 1, - "model_save_dir": temp_save_dir + "/", - "dataset_fraction": 1.0, - "powers": [0, 10, 5], - "fourier_coef_diag_values": [0, 10, 5], - "power_logscale": False, - "resume_from_checkpoint": False, - "checkpoint_interval": 1000, - "checkpoint_path": None, - "template_type": "one_hot", - "run_start_time": "test_run", - } - - main_run(config) - - assert mock_wandb["init"].call_count >= 1 - assert mock_wandb["finish"].call_count >= 1 - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/test/test_bal_power.py b/test/test_bal_power.py deleted file mode 100644 index 72f4723..0000000 --- a/test/test_bal_power.py +++ /dev/null @@ -1,26 +0,0 @@ -import numpy as np -from escnn.group import Octahedral - -from src.power import GroupPower -from src.templates import fixed_group_template - - 
-def test_power_custom_template(): - group = Octahedral() - irrep_sizes = [irrep.size for irrep in group.irreps()] - print("Irrep sizes:", irrep_sizes) - seed = 42 - powers = [0.0, 20.0, 20.0, 100.0, 0.0] # on irreps [1, 3, 3, 2, 1] - fourier_coef_diag_values = [ - np.sqrt(group.order() * p / dim**2) for p, dim in zip(powers, irrep_sizes) - ] - template = fixed_group_template(group, fourier_coef_diag_values=fourier_coef_diag_values) - - gp = GroupPower(template, group) - power = gp.power - expected_powers = powers - - print("Computed power spectrum:", power) - print("Expected powers:", expected_powers) - print("Max diff:", np.max(np.abs(power - expected_powers))) - assert np.allclose(power, expected_powers), "Power spectrum does not match expected values" diff --git a/test/test_bal_templates.py b/test/test_bal_templates.py deleted file mode 100644 index e71209f..0000000 --- a/test/test_bal_templates.py +++ /dev/null @@ -1,156 +0,0 @@ -"""Tests for src.templates module.""" - -import numpy as np -import pytest - -from src.templates import ( - fixed_cn_template, - fixed_cnxcn_template, - fixed_group_template, - one_hot, -) - - -class TestOneHot: - """Tests for one_hot function.""" - - def test_output_shape(self): - """Test that output shape is correct.""" - p = 7 - template = one_hot(p) - - assert template.shape == (p,), f"Expected shape ({p},), got {template.shape}" - - def test_mean_centered(self): - """Test that the template is mean-centered.""" - p = 10 - template = one_hot(p) - - np.testing.assert_allclose(template.mean(), 0, atol=1e-10) - - def test_has_spike(self): - """Test that template has a spike at index 1.""" - p = 5 - template = one_hot(p) - - # The spike value should be 10 - mean - zeroth_freq = 10 / p # Mean of array with value 10 at index 1 - expected_spike_value = 10 - zeroth_freq - - np.testing.assert_allclose(template[1], expected_spike_value, rtol=1e-5) - - -class TestFixedCnTemplate: - """Tests for fixed_cn_template function.""" - - def 
test_output_shape(self): - """Test that output shape is correct.""" - group_size = 8 - fourier_coef_mags = [0, 5, 3, 2, 1] # Include DC and some frequencies - - template = fixed_cn_template(group_size, fourier_coef_mags) - - assert template.shape == (group_size,), ( - f"Expected shape ({group_size},), got {template.shape}" - ) - - def test_mean_centered(self): - """Test that the template is mean-centered.""" - group_size = 10 - fourier_coef_mags = [0, 5, 3, 2] - - template = fixed_cn_template(group_size, fourier_coef_mags) - - np.testing.assert_allclose(template.mean(), 0, atol=1e-10) - - def test_real_valued(self): - """Test that the template is real-valued.""" - group_size = 8 - fourier_coef_mags = [0, 5, 3] - - template = fixed_cn_template(group_size, fourier_coef_mags) - - # Template should be real (no imaginary component) - assert np.isreal(template).all() - - -class TestFixedCnxcnTemplate: - """Tests for fixed_cnxcn_template function.""" - - def test_output_shape(self): - """Test that output shape is correct (flattened).""" - image_length = 6 - fourier_coef_mags = [0, 5, 3, 2] # DC and some frequencies - - template = fixed_cnxcn_template(image_length, fourier_coef_mags) - - expected_size = image_length * image_length - assert template.shape == (expected_size,), ( - f"Expected shape ({expected_size},), got {template.shape}" - ) - - def test_mean_centered(self): - """Test that the template is mean-centered.""" - image_length = 5 - fourier_coef_mags = [0, 5, 3] - - template = fixed_cnxcn_template(image_length, fourier_coef_mags) - - np.testing.assert_allclose(template.mean(), 0, atol=1e-10) - - def test_real_valued(self): - """Test that the template is real-valued.""" - image_length = 4 - fourier_coef_mags = [0, 5] - - template = fixed_cnxcn_template(image_length, fourier_coef_mags) - - assert np.isreal(template).all() - - -class TestFixedGroupTemplate: - """Tests for fixed_group_template function.""" - - @pytest.fixture - def dihedral_group(self): - """Create a 
DihedralGroup for testing.""" - from escnn.group import DihedralGroup - - return DihedralGroup(N=3) # D3 has 5 irreps - - def test_output_shape(self, dihedral_group): - """Test that output shape matches group order.""" - group_order = dihedral_group.order() # 6 for D3 - num_irreps = len(list(dihedral_group.irreps())) # 5 for D3 - fourier_coef_diag_values = [1.0] * num_irreps - - template = fixed_group_template(dihedral_group, fourier_coef_diag_values) - - assert template.shape == (group_order,), ( - f"Expected shape ({group_order},), got {template.shape}" - ) - - def test_mean_centered(self, dihedral_group): - """Test that the template is mean-centered.""" - num_irreps = len(list(dihedral_group.irreps())) - fourier_coef_diag_values = [1.0] * num_irreps - - template = fixed_group_template(dihedral_group, fourier_coef_diag_values) - - np.testing.assert_allclose(template.mean(), 0, atol=1e-10) - - def test_wrong_num_coefs_error(self, dihedral_group): - """Test that mismatched number of coefficients raises error.""" - wrong_num_coefs = [1.0, 2.0] # Wrong number - - with pytest.raises(AssertionError): - fixed_group_template(dihedral_group, wrong_num_coefs) - - def test_real_valued(self, dihedral_group): - """Test that the template is real-valued.""" - num_irreps = len(list(dihedral_group.irreps())) - fourier_coef_diag_values = [1.0] * num_irreps - - template = fixed_group_template(dihedral_group, fourier_coef_diag_values) - - assert np.isreal(template).all() diff --git a/test/test_rnns_config.yaml b/test/test_config.yaml similarity index 93% rename from test/test_rnns_config.yaml rename to test/test_config.yaml index 1f0250a..f4767e4 100644 --- a/test/test_rnns_config.yaml +++ b/test/test_config.yaml @@ -1,5 +1,5 @@ # Minimal test configuration for src/main.py -# Used by test_rnns_main.py for fast testing +# Used by test_main.py for fast testing data: group_name: cn diff --git a/test/test_rnns_datamodule.py b/test/test_dataset.py similarity index 58% rename from 
test/test_rnns_datamodule.py rename to test/test_dataset.py index 0db087b..f49e38d 100644 --- a/test/test_rnns_datamodule.py +++ b/test/test_dataset.py @@ -1,15 +1,10 @@ -"""Tests for gagf.rnns.datamodule module.""" +"""Tests for src.dataset module.""" import numpy as np import pytest +import torch -from src.datamodule import ( - OnlineModularAdditionDataset1D, - OnlineModularAdditionDataset2D, - build_modular_addition_sequence_dataset_1d, - build_modular_addition_sequence_dataset_2d, - build_modular_addition_sequence_dataset_D3, -) +import src.dataset as dataset class TestBuildModularAdditionSequenceDataset1D: @@ -28,7 +23,7 @@ def test_output_shape_sampled(self, template_1d): k = 3 num_samples = 100 - X, Y, sequence = build_modular_addition_sequence_dataset_1d( + X, Y, sequence = dataset.build_modular_addition_sequence_dataset_1d( p=p, template=template_1d, k=k, mode="sampled", num_samples=num_samples ) @@ -41,7 +36,7 @@ def test_output_shape_exhaustive(self, template_1d): p = len(template_1d) k = 2 - X, Y, sequence = build_modular_addition_sequence_dataset_1d( + X, Y, sequence = dataset.build_modular_addition_sequence_dataset_1d( p=p, template=template_1d, k=k, mode="exhaustive" ) @@ -56,7 +51,7 @@ def test_output_shape_return_all_outputs(self, template_1d): k = 4 num_samples = 50 - X, Y, sequence = build_modular_addition_sequence_dataset_1d( + X, Y, sequence = dataset.build_modular_addition_sequence_dataset_1d( p=p, template=template_1d, k=k, @@ -75,7 +70,7 @@ def test_rolling_correctness(self, template_1d): p = len(template_1d) k = 2 - X, Y, sequence = build_modular_addition_sequence_dataset_1d( + X, Y, sequence = dataset.build_modular_addition_sequence_dataset_1d( p=p, template=template_1d, k=k, mode="exhaustive" ) @@ -101,7 +96,7 @@ def test_output_shape_sampled(self, template_2d): k = 3 num_samples = 100 - X, Y, sequence_xy = build_modular_addition_sequence_dataset_2d( + X, Y, sequence_xy = dataset.build_modular_addition_sequence_dataset_2d( p1=p1, p2=p2, 
template=template_2d, k=k, mode="sampled", num_samples=num_samples ) @@ -120,7 +115,7 @@ def test_output_shape_exhaustive(self, template_2d): template = np.random.randn(p1, p2).astype(np.float32) k = 2 - X, Y, sequence_xy = build_modular_addition_sequence_dataset_2d( + X, Y, sequence_xy = dataset.build_modular_addition_sequence_dataset_2d( p1=p1, p2=p2, template=template, k=k, mode="exhaustive" ) @@ -147,7 +142,7 @@ def test_output_shape_sampled(self, template_d3): num_samples = 100 group_order = len(template_d3) - X, Y, sequence = build_modular_addition_sequence_dataset_D3( + X, Y, sequence = dataset.build_modular_addition_sequence_dataset_D3( template=template_d3, k=k, mode="sampled", num_samples=num_samples ) @@ -161,7 +156,7 @@ def test_output_shape_exhaustive(self, template_d3): group_order = len(template_d3) n_elements = group_order # D3 has 6 elements - X, Y, sequence = build_modular_addition_sequence_dataset_D3( + X, Y, sequence = dataset.build_modular_addition_sequence_dataset_D3( template=template_d3, k=k, mode="exhaustive" ) @@ -176,7 +171,7 @@ def test_output_shape_return_all_outputs(self, template_d3): num_samples = 50 group_order = len(template_d3) - X, Y, sequence = build_modular_addition_sequence_dataset_D3( + X, Y, sequence = dataset.build_modular_addition_sequence_dataset_D3( template=template_d3, k=k, mode="sampled", @@ -199,12 +194,12 @@ def test_batch_shape(self): batch_size = 16 template = np.random.randn(p).astype(np.float32) - dataset = OnlineModularAdditionDataset1D( + ds = dataset.OnlineModularAdditionDataset1D( p=p, template=template, k=k, batch_size=batch_size, device="cpu" ) # Get first batch - iterator = iter(dataset) + iterator = iter(ds) X, Y = next(iterator) assert X.shape == (batch_size, k, p), f"X shape mismatch: {X.shape}" @@ -217,7 +212,7 @@ def test_batch_shape_return_all_outputs(self): batch_size = 16 template = np.random.randn(p).astype(np.float32) - dataset = OnlineModularAdditionDataset1D( + ds = 
dataset.OnlineModularAdditionDataset1D( p=p, template=template, k=k, @@ -226,7 +221,7 @@ def test_batch_shape_return_all_outputs(self): return_all_outputs=True, ) - iterator = iter(dataset) + iterator = iter(ds) X, Y = next(iterator) assert X.shape == (batch_size, k, p) @@ -243,11 +238,11 @@ def test_batch_shape(self): batch_size = 16 template = np.random.randn(p1, p2).astype(np.float32) - dataset = OnlineModularAdditionDataset2D( + ds = dataset.OnlineModularAdditionDataset2D( p1=p1, p2=p2, template=template, k=k, batch_size=batch_size, device="cpu" ) - iterator = iter(dataset) + iterator = iter(ds) X, Y = next(iterator) p_flat = p1 * p2 @@ -261,7 +256,7 @@ def test_batch_shape_return_all_outputs(self): batch_size = 16 template = np.random.randn(p1, p2).astype(np.float32) - dataset = OnlineModularAdditionDataset2D( + ds = dataset.OnlineModularAdditionDataset2D( p1=p1, p2=p2, template=template, @@ -271,9 +266,143 @@ def test_batch_shape_return_all_outputs(self): return_all_outputs=True, ) - iterator = iter(dataset) + iterator = iter(ds) X, Y = next(iterator) p_flat = p1 * p2 assert X.shape == (batch_size, k, p_flat) assert Y.shape == (batch_size, k - 1, p_flat) + + +class TestCnDataset: + """Tests for cn_dataset function.""" + + def test_output_shape(self): + """Test that output shapes are correct.""" + group_size = 7 + template = np.random.randn(group_size) + + X, Y = dataset.cn_dataset(template) + + n_samples = group_size**2 + assert X.shape == (n_samples, 2, group_size), f"X shape mismatch: {X.shape}" + assert Y.shape == (n_samples, group_size), f"Y shape mismatch: {Y.shape}" + + def test_modular_addition_property(self): + """Test that Y is the rolled template by (a+b) mod p.""" + group_size = 5 + template = np.arange(group_size).astype(float) + + X, Y = dataset.cn_dataset(template) + + # Check a specific case: a=1, b=2 -> q=(1+2)%5=3 + idx = 1 * group_size + 2 + expected_y = np.roll(template, 3) + np.testing.assert_allclose(Y[idx], expected_y) + + def 
test_covers_all_pairs(self): + """Test that all pairs (a, b) are covered.""" + group_size = 4 + template = np.random.randn(group_size) + + X, Y = dataset.cn_dataset(template) + + assert X.shape[0] == group_size**2 + + +class TestCnxcnDataset: + """Tests for cnxcn_dataset function.""" + + def test_output_shape(self): + """Test that output shapes are correct.""" + image_length = 4 + template = np.random.randn(image_length * image_length) + + X, Y = dataset.cnxcn_dataset(template) + + n_samples = image_length**4 + n_features = image_length * image_length + assert X.shape == (n_samples, 2, n_features), f"X shape mismatch: {X.shape}" + assert Y.shape == (n_samples, n_features), f"Y shape mismatch: {Y.shape}" + + def test_covers_all_combinations(self): + """Test that all combinations are covered.""" + image_length = 3 + template = np.random.randn(image_length * image_length) + + X, Y = dataset.cnxcn_dataset(template) + + expected_n = image_length**4 + assert X.shape[0] == expected_n + + +class TestGroupDataset: + """Tests for group_dataset function.""" + + @pytest.fixture + def dihedral_group(self): + """Create a DihedralGroup for testing.""" + from escnn.group import DihedralGroup + + return DihedralGroup(N=3) + + def test_output_shape(self, dihedral_group): + """Test that output shapes are correct for D3.""" + group_order = dihedral_group.order() + template = np.random.randn(group_order) + + X, Y = dataset.group_dataset(dihedral_group, template) + + n_samples = group_order**2 + assert X.shape == (n_samples, 2, group_order), f"X shape mismatch: {X.shape}" + assert Y.shape == (n_samples, group_order), f"Y shape mismatch: {Y.shape}" + + def test_template_length_mismatch_error(self, dihedral_group): + """Test that mismatched template length raises error.""" + wrong_size = dihedral_group.order() + 1 + template = np.random.randn(wrong_size) + + with pytest.raises(AssertionError): + dataset.group_dataset(dihedral_group, template) + + +class TestMoveDatasetToDeviceAndFlatten: 
+ """Tests for move_dataset_to_device_and_flatten function.""" + + def test_output_shape_and_type(self): + """Test that output shapes and types are correct.""" + group_size = 5 + n_samples = 10 + + X = np.random.randn(n_samples, 2, group_size) + Y = np.random.randn(n_samples, group_size) + + X_tensor, Y_tensor, device = dataset.move_dataset_to_device_and_flatten(X, Y, device="cpu") + + assert isinstance(X_tensor, torch.Tensor) + assert isinstance(Y_tensor, torch.Tensor) + assert X_tensor.shape == (n_samples, 2 * group_size) + assert Y_tensor.shape == (n_samples, group_size) + + def test_flattening(self): + """Test that X is correctly flattened.""" + group_size = 4 + n_samples = 5 + + X = np.arange(n_samples * 2 * group_size).reshape(n_samples, 2, group_size).astype(float) + Y = np.random.randn(n_samples, group_size) + + X_tensor, Y_tensor, device = dataset.move_dataset_to_device_and_flatten(X, Y, device="cpu") + + expected_flat = np.concatenate([X[0, 0, :], X[0, 1, :]]) + np.testing.assert_allclose(X_tensor[0].numpy(), expected_flat) + + def test_device_cpu(self): + """Test explicit CPU device.""" + X = np.random.randn(5, 2, 4) + Y = np.random.randn(5, 4) + + X_tensor, Y_tensor, device = dataset.move_dataset_to_device_and_flatten(X, Y, device="cpu") + + assert X_tensor.device.type == "cpu" + assert Y_tensor.device.type == "cpu" diff --git a/test/test_default_config.py b/test/test_default_config.py deleted file mode 100644 index b46bde4..0000000 --- a/test/test_default_config.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -Minimal test configuration for binary_action_learning/main.py - -This module mimics the structure of group_agf/binary_action_learning/default_config.py -but with minimal values for fast testing. 
-""" - -# Dataset Parameters -group_name = "cn" -group_n = [5] # Small cyclic group C5 -template_type = "one_hot" - -powers = { - "cn": [[0, 10, 5]], # Single power configuration - "cnxcn": [[0, 10, 5]], - "dihedral": [[0, 5, 0]], - "octahedral": [[0, 10, 0, 0, 0]], - "A5": [[0, 10, 0, 0, 0]], -} - -# Model Parameters -hidden_factor = [2] # Small hidden size - -# Learning Parameters -seed = [42] -init_scale = { - "cn": [1e-2], - "cnxcn": [1e-2], - "dihedral": [1e-2], - "octahedral": [1e-3], - "A5": [1e-3], -} -lr = { - "cn": [0.01], - "cnxcn": [0.01], - "dihedral": [0.01], - "octahedral": [0.001], - "A5": [0.001], -} -mom = [0.9] -optimizer_name = ["SGD"] # Simple optimizer -epochs = [2] # Minimal epochs -verbose_interval = 1 -checkpoint_interval = 1000 -batch_size = [32] # Small batch size - -# Plotting parameters -power_logscale = False - -# Checkpoint settings -resume_from_checkpoint = False -checkpoint_epoch = 0 - -# cnxcn specific parameters -image_length = [3] # Small image for cnxcn - -dataset_fraction = { - "cn": 1.0, - "cnxcn": 1.0, - "dihedral": 1.0, - "octahedral": 1.0, - "A5": 1.0, -} - -# Use temp directory - will be overwritten in tests -model_save_dir = "/tmp/test_bal/" diff --git a/test/test_fourier.py b/test/test_fourier.py new file mode 100644 index 0000000..106ee26 --- /dev/null +++ b/test/test_fourier.py @@ -0,0 +1,36 @@ +import numpy as np +from escnn.group import Octahedral + +import src.fourier as fourier +import src.template as template + + +def test_group_fourier_inverse_is_identity(): + """Test that group_fourier followed by group_fourier_inverse reconstructs the original.""" + group = Octahedral() + + tpl = template.fixed_group(group, fourier_coef_diag_values=[100.0, 20.0, 0.0, 0.0, 0.0]) + + fourier_coefs = fourier.group_fourier(group, tpl) + reconstructed = fourier.group_fourier_inverse(group, fourier_coefs) + + assert np.allclose(tpl, reconstructed, atol=1e-10), ( + f"Inversion failed! 
max diff: {np.max(np.abs(tpl - reconstructed))}" + ) + + +def test_group_fourier_coefs_shape(): + """Test that group_fourier returns one coefficient matrix per irrep.""" + group = Octahedral() + tpl = template.fixed_group(group, fourier_coef_diag_values=[100.0, 20.0, 0.0, 0.0, 0.0]) + + fourier_coefs = fourier.group_fourier(group, tpl) + + assert len(fourier_coefs) == len(group.irreps()) + for coef, irrep in zip(fourier_coefs, group.irreps()): + assert coef.shape == (irrep.size, irrep.size) + + +if __name__ == "__main__": + test_group_fourier_inverse_is_identity() + test_group_fourier_coefs_shape() diff --git a/test/test_main.py b/test/test_main.py index 08adf66..e598027 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -45,7 +45,7 @@ def temp_run_dir(): @pytest.fixture def mock_all_plots(): """Mock all produce_plots_* and plt.savefig/close to skip visualization entirely.""" - import src.main # noqa: F401 + import src.main as main # noqa: F401 with ( patch("src.main.produce_plots_1d") as mock_1d, @@ -76,9 +76,9 @@ def mock_savefig(): @pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1") def test_load_config(): """Test that load_config correctly loads a YAML file.""" - from src.main import load_config + import src.main as main - config = load_config(str(CONFIG_FILES["c10"])) + config = main.load_config(str(CONFIG_FILES["c10"])) assert "data" in config assert "model" in config @@ -93,10 +93,10 @@ def test_load_config(): @pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1") def test_main_c10(temp_run_dir, mock_all_plots): """Test main() with C_10 cyclic group config.""" - from src.main import load_config, train_single_run + import src.main as main - config = load_config(str(CONFIG_FILES["c10"])) - results = train_single_run(config, run_dir=temp_run_dir) + config = main.load_config(str(CONFIG_FILES["c10"])) + results = main.train_single_run(config, run_dir=temp_run_dir) assert "final_train_loss" in results 
assert "final_val_loss" in results @@ -107,10 +107,10 @@ def test_main_c10(temp_run_dir, mock_all_plots): @pytest.mark.skipif(not MAIN_TEST_MODE, reason="Only run with MAIN_TEST_MODE=1") def test_main_c4x4(temp_run_dir, mock_all_plots): """Test main() with C_4 x C_4 product group config.""" - from src.main import load_config, train_single_run + import src.main as main - config = load_config(str(CONFIG_FILES["c4x4"])) - results = train_single_run(config, run_dir=temp_run_dir) + config = main.load_config(str(CONFIG_FILES["c4x4"])) + results = main.train_single_run(config, run_dir=temp_run_dir) assert "final_train_loss" in results assert "final_val_loss" in results @@ -128,10 +128,10 @@ def test_main_d3(temp_run_dir, mock_savefig): This validates the TwoLayerNet-compatible eval data path in produce_plots_group, which is shared by octahedral and A5 (mocked in their tests for speed). """ - from src.main import load_config, train_single_run + import src.main as main - config = load_config(str(CONFIG_FILES["d3"])) - results = train_single_run(config, run_dir=temp_run_dir) + config = main.load_config(str(CONFIG_FILES["d3"])) + results = main.train_single_run(config, run_dir=temp_run_dir) assert "final_train_loss" in results assert "final_val_loss" in results @@ -145,10 +145,10 @@ def test_main_octahedral(temp_run_dir, mock_all_plots): Mocks produce_plots_group for speed (octahedral order=24, plotting is expensive). Training + data pipeline still fully exercised. 
""" - from src.main import load_config, train_single_run + import src.main as main - config = load_config(str(CONFIG_FILES["octahedral"])) - results = train_single_run(config, run_dir=temp_run_dir) + config = main.load_config(str(CONFIG_FILES["octahedral"])) + results = main.train_single_run(config, run_dir=temp_run_dir) assert "final_train_loss" in results assert "final_val_loss" in results @@ -163,10 +163,10 @@ def test_main_a5(temp_run_dir, mock_all_plots): Mocks produce_plots_group for speed (A5 order=60, plotting is expensive). Training + data pipeline still fully exercised. """ - from src.main import load_config, train_single_run + import src.main as main - config = load_config(str(CONFIG_FILES["a5"])) - results = train_single_run(config, run_dir=temp_run_dir) + config = main.load_config(str(CONFIG_FILES["a5"])) + results = main.train_single_run(config, run_dir=temp_run_dir) assert "final_train_loss" in results assert "final_val_loss" in results diff --git a/test/test_rnns_model.py b/test/test_model.py similarity index 73% rename from test/test_rnns_model.py rename to test/test_model.py index 9a92a10..722775e 100644 --- a/test/test_rnns_model.py +++ b/test/test_model.py @@ -3,125 +3,121 @@ import pytest import torch -from src.model import QuadraticRNN, SequentialMLP, TwoLayerNet +import src.model as model class TestQuadraticRNN: - """Tests for the QuadraticRNN model.""" + """Tests for model.QuadraticRNN.""" @pytest.fixture def default_params(self): """Default parameters for QuadraticRNN.""" p = 7 d = 10 - template = torch.randn(p) - return {"p": p, "d": d, "template": template} + tpl = torch.randn(p) + return {"p": p, "d": d, "template": tpl} def test_output_shape_basic(self, default_params): """Test that output shape is correct for basic forward pass.""" - model = QuadraticRNN(**default_params) + net = model.QuadraticRNN(**default_params) batch_size = 8 k = 4 p = default_params["p"] x = torch.randn(batch_size, k, p) - y = model(x) + y = net(x) assert y.shape 
== (batch_size, p), f"Expected shape {(batch_size, p)}, got {y.shape}" def test_output_shape_return_all_outputs(self, default_params): """Test output shape when return_all_outputs=True.""" params = {**default_params, "return_all_outputs": True} - model = QuadraticRNN(**params) + net = model.QuadraticRNN(**params) batch_size = 8 k = 5 p = default_params["p"] x = torch.randn(batch_size, k, p) - y = model(x) + y = net(x) - # With return_all_outputs=True, we get k-1 outputs (after first two tokens) expected_shape = (batch_size, k - 1, p) assert y.shape == expected_shape, f"Expected shape {expected_shape}, got {y.shape}" def test_output_shape_k_equals_2(self, default_params): """Test output shape when k=2 (minimum sequence length).""" - model = QuadraticRNN(**default_params) + net = model.QuadraticRNN(**default_params) batch_size = 4 k = 2 p = default_params["p"] x = torch.randn(batch_size, k, p) - y = model(x) + y = net(x) assert y.shape == (batch_size, p) def test_quadratic_transform(self, default_params): """Test that quadratic transform is applied correctly.""" params = {**default_params, "transform_type": "quadratic"} - model = QuadraticRNN(**params) + net = model.QuadraticRNN(**params) batch_size = 2 k = 3 p = default_params["p"] x = torch.randn(batch_size, k, p) - y = model(x) + y = net(x) - # Output should be finite assert torch.isfinite(y).all(), "Output contains non-finite values" def test_multiplicative_transform(self, default_params): """Test that multiplicative transform is applied correctly.""" params = {**default_params, "transform_type": "multiplicative"} - model = QuadraticRNN(**params) + net = model.QuadraticRNN(**params) batch_size = 2 k = 3 p = default_params["p"] x = torch.randn(batch_size, k, p) - y = model(x) + y = net(x) - # Output should be finite assert torch.isfinite(y).all(), "Output contains non-finite values" def test_invalid_transform_type(self, default_params): """Test that invalid transform type raises an error.""" params = 
{**default_params, "transform_type": "invalid"} - model = QuadraticRNN(**params) + net = model.QuadraticRNN(**params) x = torch.randn(2, 3, default_params["p"]) with pytest.raises(ValueError, match="Invalid transform type"): - model(x) + net(x) def test_minimum_sequence_length_error(self, default_params): """Test that k<2 raises an assertion error.""" - model = QuadraticRNN(**default_params) + net = model.QuadraticRNN(**default_params) x = torch.randn(2, 1, default_params["p"]) # k=1 with pytest.raises(AssertionError, match="Sequence length must be at least 2"): - model(x) + net(x) def test_gradient_flow(self, default_params): """Test that gradients flow through the model.""" - model = QuadraticRNN(**default_params) + net = model.QuadraticRNN(**default_params) x = torch.randn(4, 3, default_params["p"], requires_grad=True) - y = model(x) + y = net(x) loss = y.sum() loss.backward() - # Check that gradients exist for all parameters - for name, param in model.named_parameters(): + for name, param in net.named_parameters(): assert param.grad is not None, f"No gradient for {name}" assert torch.isfinite(param.grad).all(), f"Non-finite gradient for {name}" class TestSequentialMLP: - """Tests for the SequentialMLP model.""" + """Tests for model.SequentialMLP.""" @pytest.fixture def default_params(self): @@ -129,71 +125,69 @@ def default_params(self): p = 7 d = 10 k = 3 - template = torch.randn(p) - return {"p": p, "d": d, "k": k, "template": template} + tpl = torch.randn(p) + return {"p": p, "d": d, "k": k, "template": tpl} def test_output_shape(self, default_params): """Test that output shape is correct.""" - model = SequentialMLP(**default_params) + net = model.SequentialMLP(**default_params) batch_size = 8 k = default_params["k"] p = default_params["p"] x = torch.randn(batch_size, k, p) - y = model(x) + y = net(x) assert y.shape == (batch_size, p), f"Expected shape {(batch_size, p)}, got {y.shape}" def test_k_mismatch_error(self, default_params): """Test that mismatched 
k raises an error.""" - model = SequentialMLP(**default_params) + net = model.SequentialMLP(**default_params) wrong_k = default_params["k"] + 1 x = torch.randn(2, wrong_k, default_params["p"]) with pytest.raises(AssertionError, match="Expected k="): - model(x) + net(x) def test_different_k_values(self): """Test model with different k values.""" p = 5 d = 8 - template = torch.randn(p) + tpl = torch.randn(p) for k in [2, 3, 4, 5]: - model = SequentialMLP(p=p, d=d, k=k, template=template) + net = model.SequentialMLP(p=p, d=d, k=k, template=tpl) x = torch.randn(4, k, p) - y = model(x) + y = net(x) assert y.shape == (4, p), f"Failed for k={k}" def test_gradient_flow(self, default_params): """Test that gradients flow through the model.""" - model = SequentialMLP(**default_params) + net = model.SequentialMLP(**default_params) x = torch.randn(4, default_params["k"], default_params["p"], requires_grad=True) - y = model(x) + y = net(x) loss = y.sum() loss.backward() - # Check that gradients exist for all parameters - for name, param in model.named_parameters(): + for name, param in net.named_parameters(): assert param.grad is not None, f"No gradient for {name}" assert torch.isfinite(param.grad).all(), f"Non-finite gradient for {name}" def test_k_power_activation(self, default_params): """Test that k-th power activation produces finite results.""" - model = SequentialMLP(**default_params) + net = model.SequentialMLP(**default_params) - # Use small inputs to avoid overflow with k-th power x = torch.randn(4, default_params["k"], default_params["p"]) * 0.1 - y = model(x) + y = net(x) assert torch.isfinite(y).all(), "Output contains non-finite values" class TestTwoLayerNet: - """Tests for the TwoLayerNet model.""" + """Tests for model.TwoLayerNet.""" @pytest.fixture def default_params(self): @@ -202,13 +196,12 @@ def default_params(self): def test_output_shape(self, default_params): """Test that output shape is correct.""" - model = TwoLayerNet(**default_params) + net = 
model.TwoLayerNet(**default_params) batch_size = 8 group_size = default_params["group_size"] - # Input is flattened: (batch, 2 * group_size) x = torch.randn(batch_size, 2 * group_size) - y = model(x) + y = net(x) assert y.shape == ( batch_size, @@ -218,100 +211,96 @@ def test_output_shape(self, default_params): def test_square_nonlinearity(self, default_params): """Test that square nonlinearity produces finite results.""" params = {**default_params, "nonlinearity": "square"} - model = TwoLayerNet(**params) + net = model.TwoLayerNet(**params) x = torch.randn(4, 2 * default_params["group_size"]) - y = model(x) + y = net(x) assert torch.isfinite(y).all(), "Output contains non-finite values" def test_relu_nonlinearity(self, default_params): """Test that relu nonlinearity produces finite results.""" params = {**default_params, "nonlinearity": "relu"} - model = TwoLayerNet(**params) + net = model.TwoLayerNet(**params) x = torch.randn(4, 2 * default_params["group_size"]) - y = model(x) + y = net(x) assert torch.isfinite(y).all(), "Output contains non-finite values" def test_tanh_nonlinearity(self, default_params): """Test that tanh nonlinearity produces finite results.""" params = {**default_params, "nonlinearity": "tanh"} - model = TwoLayerNet(**params) + net = model.TwoLayerNet(**params) x = torch.randn(4, 2 * default_params["group_size"]) - y = model(x) + y = net(x) assert torch.isfinite(y).all(), "Output contains non-finite values" def test_gelu_nonlinearity(self, default_params): """Test that gelu nonlinearity produces finite results.""" params = {**default_params, "nonlinearity": "gelu"} - model = TwoLayerNet(**params) + net = model.TwoLayerNet(**params) x = torch.randn(4, 2 * default_params["group_size"]) - y = model(x) + y = net(x) assert torch.isfinite(y).all(), "Output contains non-finite values" def test_linear_nonlinearity(self, default_params): """Test that linear (no activation) produces finite results.""" params = {**default_params, "nonlinearity": 
"linear"} - model = TwoLayerNet(**params) + net = model.TwoLayerNet(**params) x = torch.randn(4, 2 * default_params["group_size"]) - y = model(x) + y = net(x) assert torch.isfinite(y).all(), "Output contains non-finite values" def test_invalid_nonlinearity(self, default_params): """Test that invalid nonlinearity raises an error.""" params = {**default_params, "nonlinearity": "invalid"} - model = TwoLayerNet(**params) + net = model.TwoLayerNet(**params) x = torch.randn(4, 2 * default_params["group_size"]) with pytest.raises(ValueError, match="Invalid nonlinearity"): - model(x) + net(x) def test_gradient_flow(self, default_params): """Test that gradients flow through the model.""" - model = TwoLayerNet(**default_params) + net = model.TwoLayerNet(**default_params) x = torch.randn(4, 2 * default_params["group_size"], requires_grad=True) - y = model(x) + y = net(x) loss = y.sum() loss.backward() - # Check that gradients exist for all parameters - for name, param in model.named_parameters(): + for name, param in net.named_parameters(): assert param.grad is not None, f"No gradient for {name}" assert torch.isfinite(param.grad).all(), f"Non-finite gradient for {name}" def test_default_hidden_size(self): """Test that default hidden_size is computed correctly.""" group_size = 8 - model = TwoLayerNet(group_size=group_size) + net = model.TwoLayerNet(group_size=group_size) - # Default hidden_size should be 50 * group_size - assert model.hidden_size == 50 * group_size + assert net.hidden_size == 50 * group_size def test_output_scale(self, default_params): """Test that output_scale affects the output magnitude.""" scale_small = 0.1 scale_large = 10.0 - # Same random seed for reproducibility torch.manual_seed(42) - model_small = TwoLayerNet(**default_params, output_scale=scale_small) + net_small = model.TwoLayerNet(**default_params, output_scale=scale_small) torch.manual_seed(42) - model_large = TwoLayerNet(**default_params, output_scale=scale_large) + net_large = 
model.TwoLayerNet(**default_params, output_scale=scale_large) x = torch.randn(4, 2 * default_params["group_size"]) - y_small = model_small(x) - y_large = model_large(x) + y_small = net_small(x) + y_large = net_large(x) - # Output with larger scale should have larger absolute values on average assert y_large.abs().mean() > y_small.abs().mean() diff --git a/test/test_notebooks.py b/test/test_notebooks.py index 1e8ea2e..a87daf6 100644 --- a/test/test_notebooks.py +++ b/test/test_notebooks.py @@ -45,7 +45,7 @@ def get_notebooks_dir(): # These notebooks require pre-trained model files or external data "paper_figures": "Requires pre-trained model .pkl files not included in repo", # These notebooks have import/code issues that need separate debugging - "2D": "Missing function: cannot import 'get_power_2d' from src.utils", + "2D": "Missing function: cannot import 'get_power_2d' from src.power", "znz_znz": "Missing function: datasets.choose_template() does not exist", "seq_mlp": "Plotting error: Invalid vmin/vmax values during visualization", # These notebooks have visualization code with hardcoded indices that fail with reduced p diff --git a/test/test_rnns_optimizers.py b/test/test_optimizer.py similarity index 59% rename from test/test_rnns_optimizers.py rename to test/test_optimizer.py index 8eb6bc7..82fcb89 100644 --- a/test/test_rnns_optimizers.py +++ b/test/test_optimizer.py @@ -1,14 +1,14 @@ -"""Tests for gagf.rnns.optimizers module.""" +"""Tests for src.optimizer module.""" import pytest import torch -from src.model import QuadraticRNN, SequentialMLP -from src.optimizers import HybridRNNOptimizer, PerNeuronScaledSGD +import src.model as model +import src.optimizer as optimizer class TestPerNeuronScaledSGD: - """Tests for PerNeuronScaledSGD optimizer.""" + """Tests for optimizer.PerNeuronScaledSGD.""" @pytest.fixture def sequential_mlp(self): @@ -16,140 +16,123 @@ def sequential_mlp(self): p = 5 d = 10 k = 3 - template = torch.randn(p) - return SequentialMLP(p=p, 
d=d, k=k, template=template) + tpl = torch.randn(p) + return model.SequentialMLP(p=p, d=d, k=k, template=tpl) def test_step_updates_parameters(self, sequential_mlp): """Test that optimizer step updates model parameters.""" - optimizer = PerNeuronScaledSGD(sequential_mlp, lr=0.01) + opt = optimizer.PerNeuronScaledSGD(sequential_mlp, lr=0.01) - # Store initial parameters initial_w_in = sequential_mlp.W_in.clone() initial_w_out = sequential_mlp.W_out.clone() - # Forward pass and backward x = torch.randn(4, sequential_mlp.k, sequential_mlp.p) y = sequential_mlp(x) loss = y.sum() loss.backward() - # Optimizer step - optimizer.step() + opt.step() - # Parameters should have changed assert not torch.allclose(sequential_mlp.W_in, initial_w_in), "W_in not updated" assert not torch.allclose(sequential_mlp.W_out, initial_w_out), "W_out not updated" def test_degree_inference(self, sequential_mlp): """Test that degree is correctly inferred from model.""" - optimizer = PerNeuronScaledSGD(sequential_mlp, lr=0.01) + opt = optimizer.PerNeuronScaledSGD(sequential_mlp, lr=0.01) - # Degree should be k + 1 for SequentialMLP expected_degree = sequential_mlp.k + 1 - assert optimizer.defaults["degree"] == expected_degree + assert opt.defaults["degree"] == expected_degree def test_explicit_degree(self, sequential_mlp): """Test that explicit degree overrides inference.""" explicit_degree = 5 - optimizer = PerNeuronScaledSGD(sequential_mlp, lr=0.01, degree=explicit_degree) + opt = optimizer.PerNeuronScaledSGD(sequential_mlp, lr=0.01, degree=explicit_degree) - assert optimizer.defaults["degree"] == explicit_degree + assert opt.defaults["degree"] == explicit_degree def test_finite_gradients_after_step(self, sequential_mlp): """Test that gradients remain finite after optimization step.""" - optimizer = PerNeuronScaledSGD(sequential_mlp, lr=0.01) + opt = optimizer.PerNeuronScaledSGD(sequential_mlp, lr=0.01) x = torch.randn(4, sequential_mlp.k, sequential_mlp.p) y = sequential_mlp(x) loss = 
y.sum() loss.backward() - optimizer.step() + opt.step() - # All parameters should still be finite for name, param in sequential_mlp.named_parameters(): assert torch.isfinite(param).all(), f"Non-finite values in {name}" class TestHybridRNNOptimizer: - """Tests for HybridRNNOptimizer.""" + """Tests for optimizer.HybridRNNOptimizer.""" @pytest.fixture def quadratic_rnn(self): """Create a QuadraticRNN model.""" p = 5 d = 10 - template = torch.randn(p) - return QuadraticRNN(p=p, d=d, template=template) + tpl = torch.randn(p) + return model.QuadraticRNN(p=p, d=d, template=tpl) def test_step_updates_all_parameters(self, quadratic_rnn): """Test that optimizer step updates all model parameters.""" - optimizer = HybridRNNOptimizer(quadratic_rnn, lr=0.01, adam_lr=0.001) + opt = optimizer.HybridRNNOptimizer(quadratic_rnn, lr=0.01, adam_lr=0.001) - # Store initial parameters initial_params = {name: param.clone() for name, param in quadratic_rnn.named_parameters()} - # Forward pass and backward x = torch.randn(4, 3, quadratic_rnn.p) y = quadratic_rnn(x) loss = y.sum() loss.backward() - # Optimizer step - optimizer.step() + opt.step() - # All parameters should have changed for name, param in quadratic_rnn.named_parameters(): assert not torch.allclose(param, initial_params[name]), f"{name} not updated" def test_scaled_sgd_for_mlp_params(self, quadratic_rnn): """Test that W_in, W_drive, W_out use scaled SGD.""" - optimizer = HybridRNNOptimizer(quadratic_rnn, lr=0.01) + opt = optimizer.HybridRNNOptimizer(quadratic_rnn, lr=0.01) - # The optimizer should have two param groups - assert len(optimizer.param_groups) == 2 - - # First group should be scaled_sgd - assert optimizer.param_groups[0]["type"] == "scaled_sgd" - # Second group should be adam - assert optimizer.param_groups[1]["type"] == "adam" + assert len(opt.param_groups) == 2 + assert opt.param_groups[0]["type"] == "scaled_sgd" + assert opt.param_groups[1]["type"] == "adam" def test_adam_for_w_mix(self, quadratic_rnn): """Test 
that W_mix uses Adam optimizer.""" - optimizer = HybridRNNOptimizer(quadratic_rnn, lr=0.01, adam_lr=0.001) + opt = optimizer.HybridRNNOptimizer(quadratic_rnn, lr=0.01, adam_lr=0.001) - # W_mix should be in the adam group - adam_params = list(optimizer.param_groups[1]["params"]) + adam_params = list(opt.param_groups[1]["params"]) assert len(adam_params) == 1 assert adam_params[0] is quadratic_rnn.W_mix def test_finite_parameters_after_step(self, quadratic_rnn): """Test that parameters remain finite after optimization.""" - optimizer = HybridRNNOptimizer(quadratic_rnn, lr=0.01, adam_lr=0.001) + opt = optimizer.HybridRNNOptimizer(quadratic_rnn, lr=0.01, adam_lr=0.001) x = torch.randn(4, 3, quadratic_rnn.p) y = quadratic_rnn(x) loss = y.sum() loss.backward() - optimizer.step() + opt.step() - # All parameters should be finite for name, param in quadratic_rnn.named_parameters(): assert torch.isfinite(param).all(), f"Non-finite values in {name}" def test_multiple_steps(self, quadratic_rnn): """Test that multiple optimization steps work correctly.""" - optimizer = HybridRNNOptimizer(quadratic_rnn, lr=0.01, adam_lr=0.001) + opt = optimizer.HybridRNNOptimizer(quadratic_rnn, lr=0.01, adam_lr=0.001) for _ in range(5): - optimizer.zero_grad() + opt.zero_grad() x = torch.randn(4, 3, quadratic_rnn.p) y = quadratic_rnn(x) loss = y.sum() loss.backward() - optimizer.step() + opt.step() - # All parameters should still be finite after multiple steps for name, param in quadratic_rnn.named_parameters(): assert torch.isfinite(param).all(), f"Non-finite values in {name}" diff --git a/test/test_rnns_utils.py b/test/test_power.py similarity index 53% rename from test/test_rnns_utils.py rename to test/test_power.py index 24b3857..a5b3472 100644 --- a/test/test_rnns_utils.py +++ b/test/test_power.py @@ -1,36 +1,33 @@ -"""Tests for gagf.rnns.utils module.""" +"""Tests for src.power module.""" import numpy as np +from escnn.group import Octahedral -from src.utils import ( - get_power_1d, - 
get_power_2d_adele, - topk_template_freqs, - topk_template_freqs_1d, -) +import src.power as power +import src.template as template class TestGetPower1D: - """Tests for get_power_1d function.""" + """Tests for power.get_power_1d function.""" def test_output_shape(self): """Test that output shape is correct.""" p = 10 signal = np.random.randn(p) - power, freqs = get_power_1d(signal) + pwr, freqs = power.get_power_1d(signal) expected_len = p // 2 + 1 - assert power.shape == (expected_len,), f"power shape mismatch: {power.shape}" + assert pwr.shape == (expected_len,), f"power shape mismatch: {pwr.shape}" assert freqs.shape == (expected_len,), f"freqs shape mismatch: {freqs.shape}" def test_parseval_theorem(self): - """Test that Parseval's theorem holds (total power ≈ norm squared).""" + """Test that Parseval's theorem holds (total power ~ norm squared).""" p = 16 signal = np.random.randn(p) - power, _ = get_power_1d(signal) - total_power = np.sum(power) + pwr, _ = power.get_power_1d(signal) + total_power = np.sum(pwr) norm_squared = np.linalg.norm(signal) ** 2 np.testing.assert_allclose( @@ -42,8 +39,8 @@ def test_parseval_theorem_odd_length(self): p = 15 signal = np.random.randn(p) - power, _ = get_power_1d(signal) - total_power = np.sum(power) + pwr, _ = power.get_power_1d(signal) + total_power = np.sum(pwr) norm_squared = np.linalg.norm(signal) ** 2 np.testing.assert_allclose( @@ -59,28 +56,26 @@ def test_dc_component(self): constant_value = 3.0 signal = np.full(p, constant_value) - power, freqs = get_power_1d(signal) + pwr, freqs = power.get_power_1d(signal) - # DC component should contain all the power for constant signal expected_dc_power = constant_value**2 * p - np.testing.assert_allclose(power[0], expected_dc_power, rtol=1e-6) + np.testing.assert_allclose(pwr[0], expected_dc_power, rtol=1e-6) - # All other components should be zero - assert np.allclose(power[1:], 0, atol=1e-10) + assert np.allclose(pwr[1:], 0, atol=1e-10) -class TestGetPower2DAdele: - 
"""Tests for get_power_2d_adele function.""" +class TestGetPower2D: + """Tests for power.get_power_2d function.""" def test_output_shape(self): """Test that output shape is correct.""" M, N = 8, 10 signal = np.random.randn(M, N) - freqs_u, freqs_v, power = get_power_2d_adele(signal) + freqs_u, freqs_v, pwr = power.get_power_2d(signal) expected_power_shape = (M, N // 2 + 1) - assert power.shape == expected_power_shape, f"power shape mismatch: {power.shape}" + assert pwr.shape == expected_power_shape, f"power shape mismatch: {pwr.shape}" assert freqs_u.shape == (M,), f"freqs_u shape mismatch: {freqs_u.shape}" assert freqs_v.shape == (N // 2 + 1,), f"freqs_v shape mismatch: {freqs_v.shape}" @@ -89,9 +84,8 @@ def test_output_shape_no_freq(self): M, N = 8, 10 signal = np.random.randn(M, N) - result = get_power_2d_adele(signal, no_freq=True) + result = power.get_power_2d(signal, no_freq=True) - # Should only return power expected_shape = (M, N // 2 + 1) assert result.shape == expected_shape @@ -100,8 +94,8 @@ def test_parseval_theorem(self): M, N = 12, 12 signal = np.random.randn(M, N) - power = get_power_2d_adele(signal, no_freq=True) - total_power = np.sum(power) + pwr = power.get_power_2d(signal, no_freq=True) + total_power = np.sum(pwr) norm_squared = np.linalg.norm(signal) ** 2 np.testing.assert_allclose( @@ -110,11 +104,11 @@ def test_parseval_theorem(self): def test_parseval_theorem_rectangular(self): """Test Parseval's theorem for rectangular arrays.""" - M, N = 7, 11 # Both odd + M, N = 7, 11 signal = np.random.randn(M, N) - power = get_power_2d_adele(signal, no_freq=True) - total_power = np.sum(power) + pwr = power.get_power_2d(signal, no_freq=True) + total_power = np.sum(pwr) norm_squared = np.linalg.norm(signal) ** 2 np.testing.assert_allclose( @@ -126,15 +120,15 @@ def test_parseval_theorem_rectangular(self): class TestTopkTemplateFreqs1D: - """Tests for topk_template_freqs_1d function.""" + """Tests for power.topk_template_freqs_1d function.""" def 
test_returns_top_k(self): """Test that function returns exactly K frequencies.""" p = 16 K = 3 - template = np.random.randn(p) + tpl = np.random.randn(p) - top_freqs = topk_template_freqs_1d(template, K) + top_freqs = power.topk_template_freqs_1d(tpl, K) assert len(top_freqs) == K, f"Expected {K} frequencies, got {len(top_freqs)}" @@ -142,48 +136,45 @@ def test_returns_sorted_by_power(self): """Test that frequencies are sorted by descending power.""" p = 16 K = 5 - template = np.random.randn(p) + tpl = np.random.randn(p) - top_freqs = topk_template_freqs_1d(template, K) - power, _ = get_power_1d(template) + top_freqs = power.topk_template_freqs_1d(tpl, K) + pwr, _ = power.get_power_1d(tpl) - # Get powers for returned frequencies - returned_powers = [power[f] for f in top_freqs] + returned_powers = [pwr[f] for f in top_freqs] - # Should be in descending order assert returned_powers == sorted(returned_powers, reverse=True) def test_empty_for_zero_signal(self): """Test that zero signal with high min_power returns empty list.""" p = 8 - template = np.zeros(p) + tpl = np.zeros(p) - top_freqs = topk_template_freqs_1d(template, K=3, min_power=1e-10) + top_freqs = power.topk_template_freqs_1d(tpl, K=3, min_power=1e-10) assert top_freqs == [] def test_handles_k_larger_than_freqs(self): """Test behavior when K is larger than available frequencies.""" p = 6 - K = 10 # More than available frequencies - template = np.random.randn(p) + K = 10 + tpl = np.random.randn(p) - top_freqs = topk_template_freqs_1d(template, K) + top_freqs = power.topk_template_freqs_1d(tpl, K) - # Should return at most p//2 + 1 frequencies assert len(top_freqs) <= p // 2 + 1 class TestTopkTemplateFreqs: - """Tests for topk_template_freqs function (2D).""" + """Tests for power.topk_template_freqs function (2D).""" def test_returns_top_k(self): """Test that function returns exactly K frequency pairs.""" p1, p2 = 8, 8 K = 3 - template = np.random.randn(p1, p2) + tpl = np.random.randn(p1, p2) - top_freqs = 
topk_template_freqs(template, K) + top_freqs = power.topk_template_freqs(tpl, K) assert len(top_freqs) == K, f"Expected {K} frequency pairs, got {len(top_freqs)}" @@ -191,9 +182,9 @@ def test_returns_tuples(self): """Test that returned values are (kx, ky) tuples.""" p1, p2 = 8, 8 K = 3 - template = np.random.randn(p1, p2) + tpl = np.random.randn(p1, p2) - top_freqs = topk_template_freqs(template, K) + top_freqs = power.topk_template_freqs(tpl, K) for freq in top_freqs: assert isinstance(freq, tuple), f"Expected tuple, got {type(freq)}" @@ -202,8 +193,45 @@ def test_returns_tuples(self): def test_empty_for_zero_signal(self): """Test that zero signal returns empty list.""" p1, p2 = 6, 6 - template = np.zeros((p1, p2)) + tpl = np.zeros((p1, p2)) - top_freqs = topk_template_freqs(template, K=3, min_power=1e-10) + top_freqs = power.topk_template_freqs(tpl, K=3, min_power=1e-10) assert top_freqs == [] + + +class TestGroupPower: + """Tests for power.GroupPower class.""" + + def test_group_power_spectrum(self): + """Test that power.GroupPower computes correct power spectrum.""" + group = Octahedral() + irrep_sizes = [irrep.size for irrep in group.irreps()] + powers = [0.0, 20.0, 20.0, 100.0, 0.0] + fourier_coef_diag_values = [ + np.sqrt(group.order() * p / dim**2) for p, dim in zip(powers, irrep_sizes) + ] + tpl = template.fixed_group(group, fourier_coef_diag_values=fourier_coef_diag_values) + + gp = power.GroupPower(tpl, group) + + assert np.allclose(gp.power, powers), f"Power spectrum mismatch: {gp.power} vs {powers}" + + +class TestGroupPowerSpectrum: + """Tests for standalone power.group_power_spectrum function.""" + + def test_matches_class_method(self): + """Test that standalone function matches GroupPower class result.""" + group = Octahedral() + irrep_sizes = [irrep.size for irrep in group.irreps()] + powers = [0.0, 20.0, 20.0, 100.0, 0.0] + fourier_coef_diag_values = [ + np.sqrt(group.order() * p / dim**2) for p, dim in zip(powers, irrep_sizes) + ] + tpl = 
template.fixed_group(group, fourier_coef_diag_values=fourier_coef_diag_values) + + spectrum = power.group_power_spectrum(group, tpl) + gp = power.GroupPower(tpl, group) + + np.testing.assert_allclose(spectrum, gp.power, atol=1e-10) diff --git a/test/test_template.py b/test/test_template.py new file mode 100644 index 0000000..ba49e21 --- /dev/null +++ b/test/test_template.py @@ -0,0 +1,143 @@ +"""Tests for src.template module.""" + +import numpy as np +import pytest + +import src.template as template + + +class TestOneHot: + """Tests for template.one_hot function.""" + + def test_output_shape(self): + """Test that output shape is correct.""" + p = 7 + tpl = template.one_hot(p) + + assert tpl.shape == (p,), f"Expected shape ({p},), got {tpl.shape}" + + def test_mean_centered(self): + """Test that the template is mean-centered.""" + p = 10 + tpl = template.one_hot(p) + + np.testing.assert_allclose(tpl.mean(), 0, atol=1e-10) + + def test_has_spike(self): + """Test that template has a spike at index 1.""" + p = 5 + tpl = template.one_hot(p) + + zeroth_freq = 10 / p + expected_spike_value = 10 - zeroth_freq + + np.testing.assert_allclose(tpl[1], expected_spike_value, rtol=1e-5) + + +class TestFixedCn: + """Tests for template.fixed_cn function.""" + + def test_output_shape(self): + """Test that output shape is correct.""" + group_size = 8 + fourier_coef_mags = [0, 5, 3, 2, 1] + + tpl = template.fixed_cn(group_size, fourier_coef_mags) + + assert tpl.shape == (group_size,), f"Expected shape ({group_size},), got {tpl.shape}" + + def test_mean_centered(self): + """Test that the template is mean-centered.""" + group_size = 10 + fourier_coef_mags = [0, 5, 3, 2] + + tpl = template.fixed_cn(group_size, fourier_coef_mags) + + np.testing.assert_allclose(tpl.mean(), 0, atol=1e-10) + + def test_real_valued(self): + """Test that the template is real-valued.""" + group_size = 8 + fourier_coef_mags = [0, 5, 3] + + tpl = template.fixed_cn(group_size, fourier_coef_mags) + + assert 
np.isreal(tpl).all() + + +class TestFixedCnxcn: + """Tests for template.fixed_cnxcn function.""" + + def test_output_shape(self): + """Test that output shape is correct (flattened).""" + image_length = 6 + fourier_coef_mags = [0, 5, 3, 2] + + tpl = template.fixed_cnxcn(image_length, fourier_coef_mags) + + expected_size = image_length * image_length + assert tpl.shape == (expected_size,), f"Expected shape ({expected_size},), got {tpl.shape}" + + def test_mean_centered(self): + """Test that the template is mean-centered.""" + image_length = 5 + fourier_coef_mags = [0, 5, 3] + + tpl = template.fixed_cnxcn(image_length, fourier_coef_mags) + + np.testing.assert_allclose(tpl.mean(), 0, atol=1e-10) + + def test_real_valued(self): + """Test that the template is real-valued.""" + image_length = 4 + fourier_coef_mags = [0, 5] + + tpl = template.fixed_cnxcn(image_length, fourier_coef_mags) + + assert np.isreal(tpl).all() + + +class TestFixedGroup: + """Tests for template.fixed_group function.""" + + @pytest.fixture + def dihedral_group(self): + """Create a DihedralGroup for testing.""" + from escnn.group import DihedralGroup + + return DihedralGroup(N=3) + + def test_output_shape(self, dihedral_group): + """Test that output shape matches group order.""" + group_order = dihedral_group.order() + num_irreps = len(list(dihedral_group.irreps())) + fourier_coef_diag_values = [1.0] * num_irreps + + tpl = template.fixed_group(dihedral_group, fourier_coef_diag_values) + + assert tpl.shape == (group_order,), f"Expected shape ({group_order},), got {tpl.shape}" + + def test_mean_centered(self, dihedral_group): + """Test that the template is mean-centered.""" + num_irreps = len(list(dihedral_group.irreps())) + fourier_coef_diag_values = [1.0] * num_irreps + + tpl = template.fixed_group(dihedral_group, fourier_coef_diag_values) + + np.testing.assert_allclose(tpl.mean(), 0, atol=1e-10) + + def test_wrong_num_coefs_error(self, dihedral_group): + """Test that mismatched number of coefficients 
raises error.""" + wrong_num_coefs = [1.0, 2.0] + + with pytest.raises(AssertionError): + template.fixed_group(dihedral_group, wrong_num_coefs) + + def test_real_valued(self, dihedral_group): + """Test that the template is real-valued.""" + num_irreps = len(list(dihedral_group.irreps())) + fourier_coef_diag_values = [1.0] * num_irreps + + tpl = template.fixed_group(dihedral_group, fourier_coef_diag_values) + + assert np.isreal(tpl).all() From 04671d1973201f0727df84e9ceeb384cd0c3f64c Mon Sep 17 00:00:00 2001 From: Nina Miolane Date: Fri, 6 Feb 2026 22:11:46 +0000 Subject: [PATCH 2/2] fix install in ci --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 137106c..4e0a858 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,6 +4,7 @@ version = "0.1.0" description = "Learning" authors = ["John Smith "] readme = "README.md" +packages = [{include = "src"}] [tool.poetry.dependencies] python = "^3.12"