2 changes: 1 addition & 1 deletion .github/workflows/make_docs.yaml
@@ -31,7 +31,7 @@ jobs:
- run: pip install pdoc
# ADJUST THIS: build your documentation into docs/.
# We use a custom build script for pdoc itself, ideally you just run `pdoc -o docs/ ...` here.
- - run: pdoc memorax -o docs --math
+ - run: pdoc memax -o docs --math

- uses: actions/upload-pages-artifact@v4
with:
2 changes: 1 addition & 1 deletion .github/workflows/python-publish.yml
@@ -51,7 +51,7 @@ jobs:
environment:
name: pypi
# OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status:
- # url: https://pypi.org/p/memorax
+ # url: https://pypi.org/p/memax
#
# ALTERNATIVE: if your GitHub Release name is the PyPI project version string
# ALTERNATIVE: exactly, uncomment the following line instead:
2 changes: 1 addition & 1 deletion .gitignore
@@ -1,4 +1,4 @@
- memorax.egg-info/
+ memax.egg-info/
dist/
build/
*.pyc
85 changes: 45 additions & 40 deletions README.md
@@ -1,8 +1,8 @@
- # Memorax - Sequence and Memory Modeling in JAX
+ # Memax - Sequence and Memory Modeling in JAX

- [![Tests](https://github.com/smorad/memorax/actions/workflows/python_app.yaml/badge.svg)](https://github.com/smorad/memorax/actions/workflows/python_app.yaml)
+ [![Tests](https://github.com/smorad/memax/actions/workflows/python_app.yaml/badge.svg)](https://github.com/smorad/memax/actions/workflows/python_app.yaml)

- Memorax is a library for efficient recurrent models. Using category theory, we utilize a [simple interface](memorax/equinox/groups.py) that should work for nearly all recurrent models. We provide a unified interface for fast recurrent state resets across the sequence, allowing you to train over batches of variable-length sequences without sequence truncation or zero-padding.
+ Memax is a library for efficient recurrent models. Using category theory, we utilize a [simple interface](memax/equinox/groups.py) that should work for nearly all recurrent models. We provide a unified interface for fast recurrent state resets across the sequence, allowing you to train over batches of variable-length sequences without sequence truncation or zero-padding.

## Table of Contents
1. [Models](#recurrent-models)
@@ -16,56 +16,61 @@ We implement both linear and log-complexity recurrent models.

| Name | Parallel Time Complexity | Paper | Code |
|------|--------------------------|-------|------|
- | Linear Recurrent Unit | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2303.06349) | [[code]](memorax/equinox/semigroups/lru.py) |
- | Selective State Space Model (S6) | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2312.00752) | [[code]](memorax/equinox/semigroups/s6.py) |
- | Linear Recurrent Neural Network | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/1709.04057) | [[code]](memorax/equinox/semigroups/lrnn.py) |
- | Fast Autoregressive Transformer | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2006.16236) | [[code]](memorax/equinox/semigroups/fart.py) |
- | Fast and Forgetful Memory | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2310.04128) | [[code]](memorax/equinox/semigroups/ffm.py) |
- | Rotational RNN (RotRNN) | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2407.07239) | [[code]](memorax/equinox/semigroups/spherical.py) |
- | Fast Weight Programmer | $O(\log{n})$ | [[paper]](https://arxiv.org/pdf/2508.08435) | [[code]](memorax/equinox/semigroups/fwp.py) |
- | DeltaNet | $O(\log{n})$ | [[paper]](https://arxiv.org/pdf/2406.06484) | [[code]](memorax/equinox/semigroups/delta.py) |
- | Gated DeltaNet | $O(\log{n})$ | [[paper]](https://arxiv.org/pdf/2412.06464) | [[code]](memorax/equinox/semigroups/gdn.py) |
- | DeltaProduct | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2502.10297) | [[code]](memorax/equinox/semigroups/deltap.py) |
- | Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/1706.03762) | [[code]](memorax/equinox/semigroups/attn.py) |
- | RoPE-Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2104.09864) | [[code]](memorax/equinox/semigroups/attn.py) |
- | ALiBi-Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2108.12409) | [[code]](memorax/equinox/semigroups/attn.py) |
- | Elman Network | $O(n)$ | [[paper]](https://www.sciencedirect.com/science/article/pii/036402139090002E) | [[code]](memorax/equinox/set_actions/elman.py) |
- | Gated Recurrent Unit | $O(n)$ | [[paper]](https://arxiv.org/abs/1412.3555) | [[code]](memorax/equinox/set_actions/gru.py) |
- | Minimal Gated Unit | $O(n)$ | [[paper]](https://arxiv.org/abs/1603.09420) | [[code]](memorax/equinox/set_actions/mgu.py) |
- | Long Short-Term Memory Unit | $O(n)$ | [[paper]](https://ieeexplore.ieee.org/abstract/document/6795963) | [[code]](memorax/equinox/set_actions/lstm.py) |
+ | Linear Recurrent Unit | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2303.06349) | [[code]](memax/equinox/semigroups/lru.py) |
+ | Selective State Space Model (S6) | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2312.00752) | [[code]](memax/equinox/semigroups/s6.py) |
+ | Linear Recurrent Neural Network | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/1709.04057) | [[code]](memax/equinox/semigroups/lrnn.py) |
+ | Fast Autoregressive Transformer | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2006.16236) | [[code]](memax/equinox/semigroups/fart.py) |
+ | Fast and Forgetful Memory | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2310.04128) | [[code]](memax/equinox/semigroups/ffm.py) |
+ | Rotational RNN (RotRNN) | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2407.07239) | [[code]](memax/equinox/semigroups/spherical.py) |
+ | Fast Weight Programmer | $O(\log{n})$ | [[paper]](https://arxiv.org/pdf/2508.08435) | [[code]](memax/equinox/semigroups/fwp.py) |
+ | DeltaNet | $O(\log{n})$ | [[paper]](https://arxiv.org/pdf/2406.06484) | [[code]](memax/equinox/semigroups/delta.py) |
+ | Gated DeltaNet | $O(\log{n})$ | [[paper]](https://arxiv.org/pdf/2412.06464) | [[code]](memax/equinox/semigroups/gdn.py) |
+ | DeltaProduct | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2502.10297) | [[code]](memax/equinox/semigroups/deltap.py) |
+ | Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/1706.03762) | [[code]](memax/equinox/semigroups/attn.py) |
+ | RoPE-Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2104.09864) | [[code]](memax/equinox/semigroups/attn.py) |
+ | ALiBi-Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2108.12409) | [[code]](memax/equinox/semigroups/attn.py) |
+ | Elman Network | $O(n)$ | [[paper]](https://www.sciencedirect.com/science/article/pii/036402139090002E) | [[code]](memax/equinox/set_actions/elman.py) |
+ | Gated Recurrent Unit | $O(n)$ | [[paper]](https://arxiv.org/abs/1412.3555) | [[code]](memax/equinox/set_actions/gru.py) |
+ | Minimal Gated Unit | $O(n)$ | [[paper]](https://arxiv.org/abs/1603.09420) | [[code]](memax/equinox/set_actions/mgu.py) |
+ | Long Short-Term Memory Unit | $O(n)$ | [[paper]](https://ieeexplore.ieee.org/abstract/document/6795963) | [[code]](memax/equinox/set_actions/lstm.py) |

# Datasets
- We provide [datasets](memorax/datasets) to test our recurrent models.
+ We provide [datasets](memax/datasets) to test our recurrent models.

- ### Sequential MNIST [[HuggingFace]](https://huggingface.co/datasets/ylecun/mnist) [[Code]](memorax/datasets/sequential_mnist.py)
+ ### Sequential MNIST [[HuggingFace]](https://huggingface.co/datasets/ylecun/mnist) [[Code]](memax/datasets/sequential_mnist.py)
> The recurrent model receives an MNIST image pixel by pixel, and must predict the digit class.
>
> **Sequence Lengths:** `[784]`

- ### MNIST Math [[HuggingFace]](https://huggingface.co/datasets?sort=trending&search=bolt-lab%2Fmnist-math) [[Code]](memorax/datasets/sequential_mnist.py)
+ ### MNIST Math [[HuggingFace]](https://huggingface.co/datasets?sort=trending&search=bolt-lab%2Fmnist-math) [[Code]](memax/datasets/sequential_mnist.py)
> The recurrent model receives a sequence of MNIST images and operators, pixel by pixel, and must predict the percentile of the operators applied to the MNIST image classes.
>
> **Sequence Lengths:** `[784 * 5, 784 * 100, 784 * 1_000, 784 * 10_000, 784 * 1_000_000]`

- ### Continuous Localization [[HuggingFace]](https://huggingface.co/datasets?sort=trending&search=bolt-lab%2Fcontinuous-localization) [[Code]](memorax/datasets/sequential_mnist.py)
+ ### Continuous Localization [[HuggingFace]](https://huggingface.co/datasets?sort=trending&search=bolt-lab%2Fcontinuous-localization) [[Code]](memax/datasets/sequential_mnist.py)
> The recurrent model receives a sequence of translation and rotation vectors **in the local coordinate frame**, and must predict the corresponding position and orientation **in the global coordinate frame**.
>
> **Sequence Lengths:** `[20, 100, 1_000]`
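
Pose composition is a natural fit for associative recurrences, since composing rigid-body transforms is itself associative. A toy 2D sketch of the idea in plain JAX (hypothetical, and independent of the dataset's actual tensor layout):

```python
import jax
import jax.numpy as jnp

# Each step is a local-frame motion (dx, dy, dtheta). Composing two rigid
# transforms rotates the second motion into the first one's frame; the
# composition is associative, so it parallelizes with an associative scan.
def compose(p, q):
    x1, y1, t1 = p
    x2, y2, t2 = q
    c, s = jnp.cos(t1), jnp.sin(t1)
    return x1 + c * x2 - s * y2, y1 + s * x2 + c * y2, t1 + t2

steps = (jnp.ones(20), jnp.zeros(20), jnp.full(20, 0.1))  # toy local motions
xs, ys, thetas = jax.lax.associative_scan(compose, steps)  # global trajectory
```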

# Getting Started
- Install `memorax` using pip and git for your specific framework
+ Install `memax` using pip and git for your specific framework
```bash
pip install "memorax[equinox]@git+https://github.com/smorad/memorax"
pip install "memorax[flax]@git+https://github.com/smorad/memorax"
pip install "memax[equinox]@git+https://github.com/smorad/memax"
pip install "memax[flax]@git+https://github.com/smorad/memax"
```
+ If you want to use our dataset and training scripts, install via
+ ```bash
+ pip install "memax[dataset,equinox]@git+https://github.com/smorad/memax"
+ pip install "memax[dataset,flax]@git+https://github.com/smorad/memax"
+ ```

## Equinox Quickstart
```python
- from memorax.equinox.train_utils import get_residual_memory_models
+ from memax.equinox.train_utils import get_residual_memory_models
import jax
import jax.numpy as jnp
from equinox import filter_jit, filter_vmap
- from memorax.equinox.train_utils import add_batch_dim
+ from memax.equinox.train_utils import add_batch_dim

T, F = 5, 6 # time and feature dim

@@ -96,17 +101,17 @@ python run_linen_experiments.py # flax linen framework


## Custom Architectures
- Memorax uses the [`equinox`](https://github.com/patrick-kidger/equinox) neural network library. See [the semigroups directory](memorax/equinox/semigroups) for fast recurrent models that utilize an associative scan. We also provide a beta [`flax.linen`](https://flax-linen.readthedocs.io/en/latest/) API. In this example, we focus on `equinox`.
+ memax uses the [`equinox`](https://github.com/patrick-kidger/equinox) neural network library. See [the semigroups directory](memax/equinox/semigroups) for fast recurrent models that utilize an associative scan. We also provide a beta [`flax.linen`](https://flax-linen.readthedocs.io/en/latest/) API. In this example, we focus on `equinox`.

```python
import equinox as eqx
import jax
import jax.numpy as jnp

- from memorax.equinox.set_actions.gru import GRU
- from memorax.equinox.models.residual import ResidualModel
- from memorax.equinox.semigroups.lru import LRU, LRUSemigroup
- from memorax.utils import debug_shape
+ from memax.equinox.set_actions.gru import GRU
+ from memax.equinox.models.residual import ResidualModel
+ from memax.equinox.semigroups.lru import LRU, LRUSemigroup
+ from memax.utils import debug_shape

# You can pack multiple subsequences into a single sequence using the start flag
sequence_starts = jnp.array([True, False, False, True, False])
@@ -174,19 +179,19 @@ h, y = eqx.filter_jit(model)(latest_h, inputs)
```
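
The start flags shown above are handled inside the scan itself rather than by masking or padding. As a rough sketch of the idea in plain JAX (illustrative only, not the actual memax `Resettable` implementation), a linear recurrence $h_t = a_t h_{t-1} + b_t$ can carry a reset bit through an associative combine:

```python
import jax
import jax.numpy as jnp

# Scan elements are ((a, b), start). When the right operand begins a new
# subsequence, the left operand is discarded, so no state leaks across
# packed subsequences, and the combine remains associative.
def resettable_combine(left, right):
    (a1, b1), s1 = left
    (a2, b2), s2 = right
    keep = ~s2[..., None]
    a = jnp.where(keep, a1 * a2, a2)
    b = jnp.where(keep, a2 * b1 + b2, b2)
    return (a, b), s1 | s2

decays = jnp.full((5, 4), 0.9)                         # a_t
inputs = jnp.ones((5, 4))                              # b_t
starts = jnp.array([True, False, False, True, False])  # two packed subsequences
(_, h), _ = jax.lax.associative_scan(resettable_combine, ((decays, inputs), starts))
```

Because the combined operator stays associative, resets cost nothing extra: the scan still runs in $O(\log{n})$ parallel time.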

## Creating Custom Recurrent Models
- All recurrent cells should follow the [`GRAS`](memorax/equinox/gras.py) interface. A recurrent cell consists of an `Algebra`. You can roughly think of the `Algebra` as the function that updates the recurrent state, and the `GRAS` as the `Algebra` and all the associated MLPs/gates. You may reuse our `Algebra`s in your custom `GRAS`, or even write your custom `Algebra`.
+ All recurrent cells should follow the [`GRAS`](memax/equinox/gras.py) interface. A recurrent cell consists of an `Algebra`. You can roughly think of the `Algebra` as the function that updates the recurrent state, and the `GRAS` as the `Algebra` and all the associated MLPs/gates. You may reuse our `Algebra`s in your custom `GRAS`, or even write your custom `Algebra`.

- To implement your own `Algebra` and `GRAS`, we suggest copying one from our existing code, such as the [LRNN](memorax/equinox/semigroups/lrnn.py) for a `Semigroup` or the [Elman Network](memorax/equinox/set_actions/elman.py) for a `SetAction`.
+ To implement your own `Algebra` and `GRAS`, we suggest copying one from our existing code, such as the [LRNN](memax/equinox/semigroups/lrnn.py) for a `Semigroup` or the [Elman Network](memax/equinox/set_actions/elman.py) for a `SetAction`.
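
To see why the `Semigroup` route is the fast one, compare the two contracts in plain JAX. This is a hedged sketch of the general idea, not the memax class signatures; the toy `elman_step` and `combine` functions below are illustrative:

```python
import jax
import jax.numpy as jnp

# A SetAction is an arbitrary state update, so it must run step by step.
def elman_step(h, x):  # toy update, not the real Elman cell
    h = jnp.tanh(h + x)
    return h, h

xs = jnp.ones((8, 4))
_, hs_sequential = jax.lax.scan(elman_step, jnp.zeros(4), xs)  # O(n) depth

# A Semigroup instead exposes an associative combine over states, so the
# same length-n recurrence parallelizes via an associative scan.
def combine(left, right):  # affine maps h -> a*h + b compose associatively
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2

decays = jnp.full((8, 4), 0.5)  # a_t
inputs = jnp.ones((8, 4))       # b_t
_, hs_parallel = jax.lax.associative_scan(combine, (decays, inputs))  # O(log n) depth
```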

# Documentation
- Full documentation is available [here](https://smorad.github.io/memorax/memorax.html).
+ Full documentation is available [here](https://smorad.github.io/memax/memax.html).

# Citing our Work
Please cite the library as
```
- @misc{morad_memorax_2025,
- title = {Memorax},
- url = {https://github.com/smorad/memorax},
+ @misc{morad_memax_2025,
+ title = {memax},
+ url = {https://github.com/smorad/memax},
author = {Morad, Steven and Toledo, Edan and Kortvelesy, Ryan and He, Zhe},
month = jun,
year = {2025},
4 changes: 2 additions & 2 deletions memorax/__init__.py → memax/__init__.py
@@ -1,9 +1,9 @@
"""
- Memorax: A library for memory-augmented neural networks.
+ memax: A library for memory-augmented neural networks.

We provide backends for both Equinox and Flax.

We utilize a Generalized Recurrent Algebraic Structure (GRAS) framework
to define and implement virtually any memory model we come across using a single interface.
- Take a look at `memorax.equinox.gras` or `memorax.linen.gras` for more details.
+ Take a look at `memax.equinox.gras` or `memax.linen.gras` for more details.
"""
@@ -4,7 +4,7 @@
import jax.numpy as jnp
from datasets import load_dataset # huggingface datasets

- from memorax.train_utils import get_residual_memory_models
+ from memax.train_utils import get_residual_memory_models


NUM_EPOCHS = 100
File renamed without changes.
8 changes: 8 additions & 0 deletions memax/equinox/__init__.py
@@ -0,0 +1,8 @@
"""Equinox backend for memax.

`memax.equinox.gras` provides the Generalized Recurrent Algebraic Structure (GRAS) base module.
`memax.equinox.groups` provides the algebraic structures (semigroups, groups, etc.) used in GRAS.
`memax.equinox.set_actions` provides set action-based recurrent layers (slow RNNs).
`memax.equinox.semigroups` provides semigroup-based recurrent layers (fast RNNs).
`memax.equinox.scans` provides scan functions for recurrent updates.
"""
4 changes: 2 additions & 2 deletions memorax/equinox/gras.py → memax/equinox/gras.py
@@ -4,8 +4,8 @@
import jax
from jaxtyping import PRNGKeyArray, Shaped

- from memorax.equinox.groups import BinaryAlgebra, Module
- from memorax.mtypes import Input, OutputEmbedding, RecurrentState, SingleRecurrentState
+ from memax.equinox.groups import BinaryAlgebra, Module
+ from memax.mtypes import Input, OutputEmbedding, RecurrentState, SingleRecurrentState


class GRAS(Module):
4 changes: 2 additions & 2 deletions memorax/equinox/groups.py → memax/equinox/groups.py
@@ -6,8 +6,8 @@
import jax.numpy as jnp
from jaxtyping import PRNGKeyArray, Shaped

- from memorax.mtypes import Input, RecurrentState, ResetRecurrentState, StartFlag
- from memorax.utils import debug_shape
+ from memax.mtypes import Input, RecurrentState, ResetRecurrentState, StartFlag
+ from memax.utils import debug_shape


class Module(eqx.Module):
File renamed without changes.
@@ -4,8 +4,8 @@
from equinox import filter_vmap, nn
from jaxtyping import PRNGKeyArray, Shaped

- from memorax.equinox.groups import Module
- from memorax.mtypes import Input, ResetRecurrentState
+ from memax.equinox.groups import Module
+ from memax.mtypes import Input, ResetRecurrentState


class ResidualModel(Module):
4 changes: 2 additions & 2 deletions memorax/equinox/scans.py → memax/equinox/scans.py
@@ -9,8 +9,8 @@
import jax
import jax.numpy as jnp

- from memorax.mtypes import RecurrentState
- from memorax.utils import debug_shape
+ from memax.mtypes import RecurrentState
+ from memax.utils import debug_shape



18 changes: 18 additions & 0 deletions memax/equinox/semigroups/__init__.py
@@ -0,0 +1,18 @@
"""This module contains semigroup-based recurrent layers.
They are generally much faster than standard RNNs.

Each RNN type gets its own file.
+ `memax.equinox.semigroups.attn` provides dot-product attention layer.
+ `memax.equinox.semigroups.delta` provides the DeltaNet layer.
+ `memax.equinox.semigroups.deltap` provides the DeltaProduct layer.
+ `memax.equinox.semigroups.stack` provides framestacking (sliding-window) as an RNN.
+ `memax.equinox.semigroups.fart` provides the Fast AutoRegressive Transformer layer.
+ `memax.equinox.semigroups.ffm` provides the Fast and Forgetful Memory layer.
+ `memax.equinox.semigroups.fwp` provides the Fast Weight Programmer layer.
+ `memax.equinox.semigroups.lru` provides the Linear Recurrent Unit layer.
+ `memax.equinox.semigroups.gdn` provides the Gated DeltaNet layer.
+ `memax.equinox.semigroups.lrnn` provides a basic linear recurrence.
+ `memax.equinox.semigroups.mlp` provides an MLP (no memory) for completeness.
+ `memax.equinox.semigroups.s6` provides the Selective State Space Model (Mamba) layer.
+ `memax.equinox.semigroups.spherical` provides Rotational RNN layer (spherical projection).
"""
@@ -6,11 +6,11 @@
from equinox import nn
from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped, Bool, Int

- from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable
- from memorax.equinox.gras import GRAS
- from memorax.mtypes import Input, StartFlag
- from memorax.equinox.scans import semigroup_scan
- from memorax.utils import apply_rope, apply_sinusoidal_pe, combine_and_right_align
+ from memax.equinox.groups import BinaryAlgebra, Semigroup, Resettable
+ from memax.equinox.gras import GRAS
+ from memax.mtypes import Input, StartFlag
+ from memax.equinox.scans import semigroup_scan
+ from memax.utils import apply_rope, apply_sinusoidal_pe, combine_and_right_align

AttentionRecurrentState = Tuple[Float[Array, "Window Recurrent"], Float[Array, "Window Recurrent"], Bool[Array, "Window"], Int[Array, "Window"]]
AttentionRecurrentStateWithReset = Tuple[AttentionRecurrentState, StartFlag]
@@ -6,10 +6,10 @@
from equinox import nn
from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped

- from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable
- from memorax.equinox.gras import GRAS
- from memorax.mtypes import Input, StartFlag
- from memorax.equinox.scans import semigroup_scan
+ from memax.equinox.groups import BinaryAlgebra, Semigroup, Resettable
+ from memax.equinox.gras import GRAS
+ from memax.mtypes import Input, StartFlag
+ from memax.equinox.scans import semigroup_scan

DeltaFWPRecurrentState = Tuple[
Float[Array, "Key Value"],
@@ -6,10 +6,10 @@
from equinox import nn
from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped

- from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable
- from memorax.equinox.gras import GRAS
- from memorax.mtypes import Input, StartFlag
- from memorax.equinox.scans import semigroup_scan
+ from memax.equinox.groups import BinaryAlgebra, Semigroup, Resettable
+ from memax.equinox.gras import GRAS
+ from memax.mtypes import Input, StartFlag
+ from memax.equinox.scans import semigroup_scan

DeltaProductRecurrentState = Tuple[
Float[Array, "Key Value"],
@@ -6,10 +6,10 @@
from equinox import filter_vmap, nn
from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped

- from memorax.equinox.groups import BinaryAlgebra, Module, Semigroup, Resettable
- from memorax.equinox.gras import GRAS
- from memorax.mtypes import Input, StartFlag
- from memorax.equinox.scans import semigroup_scan
+ from memax.equinox.groups import BinaryAlgebra, Module, Semigroup, Resettable
+ from memax.equinox.gras import GRAS
+ from memax.mtypes import Input, StartFlag
+ from memax.equinox.scans import semigroup_scan

DLSERecurrentState = Float[Array, "Hidden Hidden"]
DLSERecurrentStateWithReset = Tuple[DLSERecurrentState, StartFlag]
@@ -6,10 +6,10 @@
from equinox import nn
from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped

- from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable
- from memorax.equinox.gras import GRAS
- from memorax.mtypes import Input, StartFlag
- from memorax.equinox.scans import semigroup_scan
+ from memax.equinox.groups import BinaryAlgebra, Semigroup, Resettable
+ from memax.equinox.gras import GRAS
+ from memax.mtypes import Input, StartFlag
+ from memax.equinox.scans import semigroup_scan

FARTRecurrentState = Tuple[Float[Array, "Key Value"], Float[Array, "Key"]]
FARTRecurrentStateWithReset = Tuple[FARTRecurrentState, StartFlag]