A modular library for spectral transformer implementations in PyTorch. It replaces traditional attention mechanisms with Fourier transforms, wavelets, and other spectral methods.
- Modular Design: Mix and match components to create custom architectures
- Multiple Spectral Methods: FFT, DCT, DWT, Hadamard transforms, and more
- Efficient: Fast token mixing via frequency domain operations
- Type-Safe: Full type hints with Python 3.13+ support
- Well-Tested: Comprehensive test coverage
- Easy to Use: Consistent API across all models
```bash
pip install spectrans
```

For development:

```bash
git clone https://github.com/aaronstevenwhite/spectrans.git
cd spectrans
pip install -e ".[dev]"
```

Note: Windows is not currently supported. Please use Linux or macOS.
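Either way, a plain import is enough to confirm the install worked (a generic check, nothing package-specific):

```python
# Post-install sanity check: the import succeeding is the test
import spectrans
print(spectrans.__name__)  # "spectrans"
```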
```python
import torch
from spectrans.models import FNet
# Create FNet model for classification
model = FNet(
    vocab_size=30000,
    hidden_dim=768,
    num_layers=12,
    max_sequence_length=512,
    num_classes=2
)
# Forward pass with token IDs
input_ids = torch.randint(0, 30000, (2, 128)) # (batch, seq_len)
logits = model(input_ids=input_ids)
print(f"Output shape: {logits.shape}") # torch.Size([2, 2])
# Or with embeddings directly
embeddings = torch.randn(2, 128, 768) # (batch, seq_len, hidden_dim)
logits = model(inputs_embeds=embeddings)
```

| Model | Description | Key Operation |
|---|---|---|
| FNet | Token mixing via 2D Fourier transforms | FFT2D(tokens × features) |
| GFNet | Learnable frequency domain filters | FFT → element-wise multiply → iFFT |
| AFNO | Adaptive Fourier neural operators | FFT → keep top-k modes → MLP → iFFT |
| WaveletTransformer | Multi-resolution wavelet decomposition | DWT → process scales → iDWT |
| SpectralAttention | Attention via random Fourier features | φ(Q)φ(K)ᵀV where φ = RFF |
| LSTTransformer | Low-rank spectral approximation | DCT → low-rank projection → iDCT |
| FNOTransformer | Spectral convolution operators | FFT → spectral conv → iFFT + residual |
| HybridTransformer | Alternating spectral and attention layers | [Spectral, Attention, Spectral, ...] |
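To make the table concrete, here is a minimal sketch of the FNet-style key operation (an illustration of the idea, not the library's internal code): a 2D FFT over the sequence and feature dimensions, keeping only the real part.

```python
import torch

def fourier_mixing(x: torch.Tensor) -> torch.Tensor:
    """Parameter-free token mixing via a 2D FFT (FNet-style sketch)."""
    # x: (batch, seq_len, hidden_dim); FFT over the last two dims, keep the real part
    return torch.fft.fft2(x).real

x = torch.randn(2, 128, 768)
print(fourier_mixing(x).shape)  # torch.Size([2, 128, 768])
```

A standard training loop with an FNet classifier follows.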
```python
import torch
import torch.nn as nn
from torch.optim import AdamW
from spectrans.models import FNet
model = FNet(vocab_size=30000, hidden_dim=256, num_layers=6,
             max_sequence_length=128, num_classes=2)
optimizer = AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
# Training loop
for epoch in range(3):
    input_ids = torch.randint(0, 30000, (8, 128))
    labels = torch.randint(0, 2, (8,))
    logits = model(input_ids=input_ids)
    loss = criterion(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
```
```python
from spectrans.models import GFNet, AFNOModel, WaveletTransformer

# Global Filter Network with learnable filters
gfnet = GFNet(vocab_size=30000, hidden_dim=512, num_layers=8,
              max_sequence_length=256, num_classes=10)

# Adaptive Fourier Neural Operator
afno = AFNOModel(vocab_size=30000, hidden_dim=512, num_layers=8,
                 max_sequence_length=256, modes_seq=32, num_classes=10)

# Wavelet Transformer
wavelet = WaveletTransformer(vocab_size=30000, hidden_dim=512,
                             num_layers=8, wavelet="db4", levels=3,
                             max_sequence_length=256, num_classes=10)
# All models share the same interface
input_ids = torch.randint(0, 30000, (4, 256))
output = gfnet(input_ids=input_ids)  # Shape: (4, 10)
```
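Because the interface is shared, the example above can exercise all three models in one loop (a usage sketch using only the objects already constructed):

```python
# Each model maps (4, 256) token ids to (4, 10) logits
for name, m in [("GFNet", gfnet), ("AFNO", afno), ("Wavelet", wavelet)]:
    print(name, m(input_ids=input_ids).shape)
```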
```python
from spectrans.models import HybridTransformer

# Alternate between spectral and attention layers
hybrid = HybridTransformer(
    vocab_size=30000,
    hidden_dim=768,
    num_layers=12,
    spectral_type="fourier",
    spatial_type="attention",
    alternation_pattern="even_spectral",  # Even layers use spectral
    num_heads=8,
    max_sequence_length=512,
    num_classes=2
)
output = hybrid(input_ids=input_ids)
```
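The "even_spectral" pattern can be pictured like this (a conceptual sketch, not library code): even-indexed layers use spectral mixing and odd-indexed layers use attention, matching the [Spectral, Attention, Spectral, ...] ordering from the model table.

```python
# Hypothetical illustration of how an even_spectral pattern assigns layer types
layer_types = ["spectral" if i % 2 == 0 else "attention" for i in range(12)]
print(layer_types[:4])  # ['spectral', 'attention', 'spectral', 'attention']
```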
```python
from spectrans.config import ConfigBuilder

# Load model from YAML
builder = ConfigBuilder()
model = builder.build_model("examples/configs/fnet.yaml")
# Or create programmatically
from spectrans.config.models import FNetModelConfig
from spectrans.config import build_model_from_config
config = FNetModelConfig(hidden_dim=512, num_layers=10,
                         sequence_length=128, vocab_size=8000,
                         num_classes=3)
model = build_model_from_config({"model": config.model_dump()})
```
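As a point of reference, `model_dump()` is Pydantic's standard serialization call (assuming the config classes are Pydantic models, as that call suggests), so the programmatic path hands the builder a plain dict of the fields set above; the exact keys depend on FNetModelConfig's full schema:

```python
# Sketch of the dict passed to build_model_from_config (continuing the example above)
print(config.model_dump())
# e.g. {'hidden_dim': 512, 'num_layers': 10, 'sequence_length': 128,
#       'vocab_size': 8000, 'num_classes': 3, ...}
```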
```python
import torch

from spectrans.layers.mixing.base import MixingLayer
from spectrans import register_component
@register_component("mixing", "my_custom_mixing")
class MyCustomMixing(MixingLayer):
    def __init__(self, hidden_dim: int):
        super().__init__(hidden_dim=hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Your implementation here
        return x

    def get_spectral_properties(self) -> dict[str, str | bool]:
        """Return spectral properties of this layer."""
        return {
            "transform_type": "identity",
            "preserves_energy": True,
        }

    @property
    def complexity(self) -> dict[str, str]:
        return {"time": "O(n)", "space": "O(1)"}

# Use the custom component
custom_layer = MyCustomMixing(hidden_dim=768)
x = torch.randn(2, 128, 768)
output = custom_layer(x)
```

- Full Documentation: https://spectrans.readthedocs.io
- Examples: See the `examples/` directory for complete working examples
- API Reference: Available in the documentation
We welcome contributions! Please see our Contributing Guide for details.
If you use Spectrans in your research, please cite:
```bibtex
@software{spectrans,
  title = {spectrans: Modular Spectral Transformers in PyTorch},
  author = {Aaron Steven White},
  year = {2025},
  url = {https://github.com/aaronstevenwhite/spectrans},
  doi = {10.5281/zenodo.17171169}
}
```

See LICENSE for details.