From f6c6feb587e19639e3b48d1eb2056e7cd3db3f34 Mon Sep 17 00:00:00 2001 From: Steven Morad Date: Wed, 3 Dec 2025 17:53:36 +0800 Subject: [PATCH 1/2] Memorax -> memax due to pypi squatter --- .github/workflows/make_docs.yaml | 2 +- .github/workflows/python-publish.yml | 2 +- .gitignore | 2 +- README.md | 80 +++++++++---------- {memorax => memax}/__init__.py | 4 +- .../datasets/continuous_localization.py | 0 {memorax => memax}/datasets/mnist_listops.py | 2 +- {memorax => memax}/datasets/mnist_math.py | 0 .../datasets/sequential_mnist.py | 0 memax/equinox/__init__.py | 8 ++ {memorax => memax}/equinox/gras.py | 4 +- {memorax => memax}/equinox/groups.py | 4 +- {memorax => memax}/equinox/models/__init__.py | 0 {memorax => memax}/equinox/models/residual.py | 4 +- {memorax => memax}/equinox/scans.py | 4 +- memax/equinox/semigroups/__init__.py | 18 +++++ {memorax => memax}/equinox/semigroups/attn.py | 10 +-- .../equinox/semigroups/delta.py | 8 +- .../equinox/semigroups/deltap.py | 8 +- {memorax => memax}/equinox/semigroups/dlse.py | 8 +- {memorax => memax}/equinox/semigroups/fart.py | 8 +- {memorax => memax}/equinox/semigroups/ffm.py | 8 +- {memorax => memax}/equinox/semigroups/fwp.py | 8 +- {memorax => memax}/equinox/semigroups/gdn.py | 8 +- {memorax => memax}/equinox/semigroups/lrnn.py | 8 +- {memorax => memax}/equinox/semigroups/lru.py | 8 +- {memorax => memax}/equinox/semigroups/mlp.py | 8 +- {memorax => memax}/equinox/semigroups/nmax.py | 8 +- {memorax => memax}/equinox/semigroups/s6.py | 8 +- .../equinox/semigroups/spherical.py | 8 +- .../equinox/semigroups/stack.py | 10 +-- memax/equinox/set_actions/__init__.py | 9 +++ .../equinox/set_actions/elman.py | 8 +- {memorax => memax}/equinox/set_actions/gru.py | 8 +- .../equinox/set_actions/lstm.py | 8 +- {memorax => memax}/equinox/set_actions/mgu.py | 8 +- .../equinox/set_actions/spherical.py | 8 +- {memorax => memax}/equinox/train_utils.py | 44 +++++----- memax/linen/__init__.py | 8 ++ {memorax => memax}/linen/gras.py | 4 +- {memorax => memax}/linen/groups.py | 4 +- {memorax => memax}/linen/models/__init__.py | 0 {memorax => memax}/linen/models/residual.py | 4 +- {memorax => memax}/linen/scans.py | 4 +- .../linen/semigroups/__init__.py | 0 {memorax => memax}/linen/semigroups/fart.py | 8 +- {memorax => memax}/linen/semigroups/lru.py | 8 +- {memorax => memax}/linen/semigroups/s6.py | 8 +- .../linen/set_actions/__init__.py | 2 +- {memorax => memax}/linen/set_actions/gru.py | 8 +- {memorax => memax}/linen/train_utils.py | 10 +-- {memorax => memax}/mtypes.py | 0 {memorax => memax}/utils.py | 2 +- memorax/equinox/__init__.py | 8 -- memorax/equinox/semigroups/__init__.py | 18 ----- memorax/equinox/set_actions/__init__.py | 9 --- memorax/linen/__init__.py | 8 -- run_equinox_experiments.py | 12 +-- run_linen_experiments.py | 10 +-- setup.py | 4 +- tests/test_associative_equinox.py | 4 +- tests/test_associative_linen.py | 4 +- tests/test_continuous_localization.py | 2 +- tests/test_initial_input_equinox.py | 2 +- tests/test_initial_input_linen.py | 2 +- tests/test_readme.py | 12 +-- tests/test_reset_equinox.py | 2 +- tests/test_reset_linen.py | 2 +- tests/test_stack_equinox.py | 2 +- 69 files changed, 261 insertions(+), 261 deletions(-) rename {memorax => memax}/__init__.py (61%) rename {memorax => memax}/datasets/continuous_localization.py (100%) rename {memorax => memax}/datasets/mnist_listops.py (98%) rename {memorax => memax}/datasets/mnist_math.py (100%) rename {memorax => memax}/datasets/sequential_mnist.py (100%) create mode 100644 
memax/equinox/__init__.py rename {memorax => memax}/equinox/gras.py (96%) rename {memorax => memax}/equinox/groups.py (97%) rename {memorax => memax}/equinox/models/__init__.py (100%) rename {memorax => memax}/equinox/models/residual.py (96%) rename {memorax => memax}/equinox/scans.py (97%) create mode 100644 memax/equinox/semigroups/__init__.py rename {memorax => memax}/equinox/semigroups/attn.py (95%) rename {memorax => memax}/equinox/semigroups/delta.py (95%) rename {memorax => memax}/equinox/semigroups/deltap.py (96%) rename {memorax => memax}/equinox/semigroups/dlse.py (93%) rename {memorax => memax}/equinox/semigroups/fart.py (94%) rename {memorax => memax}/equinox/semigroups/ffm.py (96%) rename {memorax => memax}/equinox/semigroups/fwp.py (94%) rename {memorax => memax}/equinox/semigroups/gdn.py (95%) rename {memorax => memax}/equinox/semigroups/lrnn.py (93%) rename {memorax => memax}/equinox/semigroups/lru.py (96%) rename {memorax => memax}/equinox/semigroups/mlp.py (92%) rename {memorax => memax}/equinox/semigroups/nmax.py (93%) rename {memorax => memax}/equinox/semigroups/s6.py (95%) rename {memorax => memax}/equinox/semigroups/spherical.py (94%) rename {memorax => memax}/equinox/semigroups/stack.py (94%) create mode 100644 memax/equinox/set_actions/__init__.py rename {memorax => memax}/equinox/set_actions/elman.py (93%) rename {memorax => memax}/equinox/set_actions/gru.py (94%) rename {memorax => memax}/equinox/set_actions/lstm.py (94%) rename {memorax => memax}/equinox/set_actions/mgu.py (93%) rename {memorax => memax}/equinox/set_actions/spherical.py (94%) rename {memorax => memax}/equinox/train_utils.py (90%) create mode 100644 memax/linen/__init__.py rename {memorax => memax}/linen/gras.py (96%) rename {memorax => memax}/linen/groups.py (98%) rename {memorax => memax}/linen/models/__init__.py (100%) rename {memorax => memax}/linen/models/residual.py (96%) rename {memorax => memax}/linen/scans.py (97%) rename {memorax => memax}/linen/semigroups/__init__.py (100%) rename {memorax => memax}/linen/semigroups/fart.py (94%) rename {memorax => memax}/linen/semigroups/lru.py (97%) rename {memorax => memax}/linen/semigroups/s6.py (95%) rename {memorax => memax}/linen/set_actions/__init__.py (58%) rename {memorax => memax}/linen/set_actions/gru.py (93%) rename {memorax => memax}/linen/train_utils.py (94%) rename {memorax => memax}/mtypes.py (100%) rename {memorax => memax}/utils.py (99%) delete mode 100644 memorax/equinox/__init__.py delete mode 100644 memorax/equinox/semigroups/__init__.py delete mode 100644 memorax/equinox/set_actions/__init__.py delete mode 100644 memorax/linen/__init__.py diff --git a/.github/workflows/make_docs.yaml b/.github/workflows/make_docs.yaml index 1179304..1cda425 100644 --- a/.github/workflows/make_docs.yaml +++ b/.github/workflows/make_docs.yaml @@ -31,7 +31,7 @@ jobs: - run: pip install pdoc # ADJUST THIS: build your documentation into docs/. # We use a custom build script for pdoc itself, ideally you just run `pdoc -o docs/ ...` here. 
- - run: pdoc memorax -o docs --math + - run: pdoc memax -o docs --math - uses: actions/upload-pages-artifact@v4 with: diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 87d1713..032d3b8 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -51,7 +51,7 @@ jobs: environment: name: pypi # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status: - # url: https://pypi.org/p/memorax + # url: https://pypi.org/p/memax # # ALTERNATIVE: if your GitHub Release name is the PyPI project version string # ALTERNATIVE: exactly, uncomment the following line instead: diff --git a/.gitignore b/.gitignore index 1fd6ac6..6bb168e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -memorax.egg-info/ +memax.egg-info/ dist/ build/ *.pyc diff --git a/README.md b/README.md index 9a3dd55..6aace0a 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ -# Memorax - Sequence and Memory Modeling in JAX +# memax - Sequence and Memory Modeling in JAX -[![Tests](https://github.com/smorad/memorax/actions/workflows/python_app.yaml/badge.svg)](https://github.com/smorad/memorax/actions/workflows/python_app.yaml) +[![Tests](https://github.com/smorad/memax/actions/workflows/python_app.yaml/badge.svg)](https://github.com/smorad/memax/actions/workflows/python_app.yaml) -Memorax is a library for efficient recurrent models. Using category theory, we utilize a [simple interface](memorax/equinox/groups.py) that should work for nearly all recurrent models. +memax is a library for efficient recurrent models. Using category theory, we derive a [simple interface](memax/equinox/groups.py) that should work for nearly all recurrent models. We provide a unified interface for fast recurrent state resets across the sequence, allowing you to train over batches of variable-length sequences without sequence truncation or zero-padding. ## Table of Contents 1. [Models](#recurrent-models) @@ -16,56 +16,56 @@ We implement both linear and log-complexity recurrent models.
| Name | Parallel Time Complexity | Paper | Code | |------|--------------------------|-------|------| -| Linear Recurrent Unit | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2303.06349) | [[code]](memorax/equinox/semigroups/lru.py) | -| Selective State Space Model (S6) | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2312.00752) | [[code]](memorax/equinox/semigroups/s6.py) | -| Linear Recurrent Neural Network | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/1709.04057) | [[code]](memorax/equinox/semigroups/lrnn.py) | -| Fast Autoregressive Transformer | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2006.16236) | [[code]](memorax/equinox/semigroups/fart.py) | -| Fast and Forgetful Memory | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2310.04128) | [[code]](memorax/equinox/semigroups/ffm.py) | -| Rotational RNN (RotRNN) | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2407.07239) | [[code]](memorax/equinox/semigroups/spherical.py) | -| Fast Weight Programmer | $O(\log{n})$ | [[paper]](https://arxiv.org/pdf/2508.08435) | [[code]](memorax/equinox/semigroups/fwp.py) | -| DeltaNet | $O(\log{n})$ | [[paper]](https://arxiv.org/pdf/2406.06484) | [[code]](memorax/equinox/semigroups/delta.py) | -| Gated DeltaNet | $O(\log{n})$ | [[paper]](https://arxiv.org/pdf/2412.06464) | [[code]](memorax/equinox/semigroups/gdn.py) | -| DeltaProduct | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2502.10297) | [[code]](memorax/equinox/semigroups/deltap.py) | -| Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/1706.03762) | [[code]](memorax/equinox/semigroups/attn.py) | -| RoPE-Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2104.09864) | [[code]](memorax/equinox/semigroups/attn.py) | -| ALiBi-Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2108.12409) | [[code]](memorax/equinox/semigroups/attn.py) | -| Elman Network | $O(n)$ | [[paper]](https://www.sciencedirect.com/science/article/pii/036402139090002E) | [[code]](memorax/equinox/set_actions/elman.py) | -| Gated Recurrent Unit | $O(n)$ | [[paper]](https://arxiv.org/abs/1412.3555) | [[code]](memorax/equinox/set_actions/gru.py) | -| Minimal Gated Unit | $O(n)$ | [[paper]](https://arxiv.org/abs/1603.09420) | [[code]](memorax/equinox/set_actions/mgu.py) | -| Long Short-Term Memory Unit | $O(n)$ | [[paper]](https://ieeexplore.ieee.org/abstract/document/6795963) | [[code]](memorax/equinox/set_actions/lstm.py) | +| Linear Recurrent Unit | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2303.06349) | [[code]](memax/equinox/semigroups/lru.py) | +| Selective State Space Model (S6) | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2312.00752) | [[code]](memax/equinox/semigroups/s6.py) | +| Linear Recurrent Neural Network | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/1709.04057) | [[code]](memax/equinox/semigroups/lrnn.py) | +| Fast Autoregressive Transformer | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2006.16236) | [[code]](memax/equinox/semigroups/fart.py) | +| Fast and Forgetful Memory | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2310.04128) | [[code]](memax/equinox/semigroups/ffm.py) | +| Rotational RNN (RotRNN) | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2407.07239) | [[code]](memax/equinox/semigroups/spherical.py) | +| Fast Weight Programmer | $O(\log{n})$ | [[paper]](https://arxiv.org/pdf/2508.08435) | [[code]](memax/equinox/semigroups/fwp.py) | +| DeltaNet | $O(\log{n})$ | [[paper]](https://arxiv.org/pdf/2406.06484) | [[code]](memax/equinox/semigroups/delta.py) | +| Gated DeltaNet | $O(\log{n})$ | 
[[paper]](https://arxiv.org/pdf/2412.06464) | [[code]](memax/equinox/semigroups/gdn.py) | +| DeltaProduct | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2502.10297) | [[code]](memax/equinox/semigroups/deltap.py) | +| Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/1706.03762) | [[code]](memax/equinox/semigroups/attn.py) | +| RoPE-Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2104.09864) | [[code]](memax/equinox/semigroups/attn.py) | +| ALiBi-Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2108.12409) | [[code]](memax/equinox/semigroups/attn.py) | +| Elman Network | $O(n)$ | [[paper]](https://www.sciencedirect.com/science/article/pii/036402139090002E) | [[code]](memax/equinox/set_actions/elman.py) | +| Gated Recurrent Unit | $O(n)$ | [[paper]](https://arxiv.org/abs/1412.3555) | [[code]](memax/equinox/set_actions/gru.py) | +| Minimal Gated Unit | $O(n)$ | [[paper]](https://arxiv.org/abs/1603.09420) | [[code]](memax/equinox/set_actions/mgu.py) | +| Long Short-Term Memory Unit | $O(n)$ | [[paper]](https://ieeexplore.ieee.org/abstract/document/6795963) | [[code]](memax/equinox/set_actions/lstm.py) | # Datasets -We provide [datasets](memorax/datasets) to test our recurrent models. +We provide [datasets](memax/datasets) to test our recurrent models. -### Sequential MNIST [[HuggingFace]](https://huggingface.co/datasets/ylecun/mnist) [[Code]](memorax/datasets/sequential_mnist.py) +### Sequential MNIST [[HuggingFace]](https://huggingface.co/datasets/ylecun/mnist) [[Code]](memax/datasets/sequential_mnist.py) > The recurrent model receives an MNIST image pixel by pixel, and must predict the digit class. > > **Sequence Lengths:** `[784]` -### MNIST Math [[HuggingFace]](https://huggingface.co/datasets?sort=trending&search=bolt-lab%2Fmnist-math) [[Code]](memorax/datasets/sequential_mnist.py) +### MNIST Math [[HuggingFace]](https://huggingface.co/datasets?sort=trending&search=bolt-lab%2Fmnist-math) [[Code]](memax/datasets/mnist_math.py) > The recurrent model receives a sequence of MNIST images and operators, pixel by pixel, and must predict the percentile of the operators applied to the MNIST image classes. > > **Sequence Lengths:** `[784 * 5, 784 * 100, 784 * 1_000, 784 * 10_000, 784 * 1_000_000]` -### Continuous Localization [[HuggingFace]](https://huggingface.co/datasets?sort=trending&search=bolt-lab%2Fcontinuous-localization) [[Code]](memorax/datasets/sequential_mnist.py) +### Continuous Localization [[HuggingFace]](https://huggingface.co/datasets?sort=trending&search=bolt-lab%2Fcontinuous-localization) [[Code]](memax/datasets/continuous_localization.py) > The recurrent model receives a sequence of translation and rotation vectors **in the local coordinate frame**, and must predict the corresponding position and orientation **in the global coordinate frame**.
> > **Sequence Lengths:** `[20, 100, 1_000]` # Getting Started -Install `memorax` using pip and git for your specific framework +Install `memax` using pip and git for your specific framework: ```bash -pip install "memorax[equinox]@git+https://github.com/smorad/memorax" -pip install "memorax[flax]@git+https://github.com/smorad/memorax" +pip install "memax[equinox]@git+https://github.com/smorad/memax" +pip install "memax[flax]@git+https://github.com/smorad/memax" ``` ## Equinox Quickstart ```python -from memorax.equinox.train_utils import get_residual_memory_models +from memax.equinox.train_utils import get_residual_memory_models import jax import jax.numpy as jnp from equinox import filter_jit, filter_vmap -from memorax.equinox.train_utils import add_batch_dim +from memax.equinox.train_utils import add_batch_dim T, F = 5, 6 # time and feature dim @@ -96,17 +96,17 @@ python run_linen_experiments.py # flax linen framework ## Custom Architectures -Memorax uses the [`equinox`](https://github.com/patrick-kidger/equinox) neural network library. See [the semigroups directory](memorax/equinox/semigroups) for fast recurrent models that utilize an associative scan. We also provide a beta [`flax.linen`](https://flax-linen.readthedocs.io/en/latest/) API. In this example, we focus on `equinox`. +memax uses the [`equinox`](https://github.com/patrick-kidger/equinox) neural network library. See [the semigroups directory](memax/equinox/semigroups) for fast recurrent models that utilize an associative scan. We also provide a beta [`flax.linen`](https://flax-linen.readthedocs.io/en/latest/) API. In this example, we focus on `equinox`. ```python import equinox as eqx import jax import jax.numpy as jnp -from memorax.equinox.set_actions.gru import GRU -from memorax.equinox.models.residual import ResidualModel -from memorax.equinox.semigroups.lru import LRU, LRUSemigroup -from memorax.utils import debug_shape +from memax.equinox.set_actions.gru import GRU +from memax.equinox.models.residual import ResidualModel +from memax.equinox.semigroups.lru import LRU, LRUSemigroup +from memax.utils import debug_shape # You can pack multiple subsequences into a single sequence using the start flag sequence_starts = jnp.array([True, False, False, True, False]) @@ -174,19 +174,19 @@ h, y = eqx.filter_jit(model)(latest_h, inputs) ``` ## Creating Custom Recurrent Models -All recurrent cells should follow the [`GRAS`](memorax/equinox/gras.py) interface. A recurrent cell consists of an `Algebra`. You can roughly think of the `Algebra` as the function that updates the recurrent state, and the `GRAS` as the `Algebra` and all the associated MLPs/gates. You may reuse our `Algebra`s in your custom `GRAS`, or even write your custom `Algebra`. +All recurrent cells should follow the [`GRAS`](memax/equinox/gras.py) interface. A recurrent cell consists of an `Algebra`. You can roughly think of the `Algebra` as the function that updates the recurrent state, and the `GRAS` as the `Algebra` and all the associated MLPs/gates. You may reuse our `Algebra`s in your custom `GRAS`, or even write your own `Algebra`. -To implement your own `Algebra` and `GRAS`, we suggest copying one from our existing code, such as the [LRNN](memorax/equinox/semigroups/lrnn.py) for a `Semigroup` or the [Elman Network](memorax/equinox/set_actions/elman.py) for a `SetAction`.
+To implement your own `Algebra` and `GRAS`, we suggest copying one from our existing code, such as the [LRNN](memax/equinox/semigroups/lrnn.py) for a `Semigroup` or the [Elman Network](memax/equinox/set_actions/elman.py) for a `SetAction`. # Documentation -Full documentation is available [here](https://smorad.github.io/memorax/memorax.html). +Full documentation is available [here](https://smorad.github.io/memax/memax.html). # Citing our Work Please cite the library as ``` -@misc{morad_memorax_2025, - title = {Memorax}, - url = {https://github.com/smorad/memorax}, +@misc{morad_memax_2025, + title = {memax}, + url = {https://github.com/smorad/memax}, author = {Morad, Steven and Toledo, Edan and Kortvelesy, Ryan and He, Zhe}, month = jun, year = {2025}, diff --git a/memorax/__init__.py b/memax/__init__.py similarity index 61% rename from memorax/__init__.py rename to memax/__init__.py index 7ab904f..bb68856 100644 --- a/memorax/__init__.py +++ b/memax/__init__.py @@ -1,9 +1,9 @@ """ -Memorax: A library for memory-augmented neural networks. +memax: A library for memory-augmented neural networks. We provide backends for both Equinox and Flax. We utilize a Generalized Recurrent Algebraic Structure (GRAS) framework to define and implement virtually any memory model we come across using a single interface. -Take a look at `memorax.equinox.gras` or `memorax.linen.gras` for more details. +Take a look at `memax.equinox.gras` or `memax.linen.gras` for more details. """ \ No newline at end of file diff --git a/memorax/datasets/continuous_localization.py b/memax/datasets/continuous_localization.py similarity index 100% rename from memorax/datasets/continuous_localization.py rename to memax/datasets/continuous_localization.py diff --git a/memorax/datasets/mnist_listops.py b/memax/datasets/mnist_listops.py similarity index 98% rename from memorax/datasets/mnist_listops.py rename to memax/datasets/mnist_listops.py index 0a45017..bcea732 100644 --- a/memorax/datasets/mnist_listops.py +++ b/memax/datasets/mnist_listops.py @@ -4,7 +4,7 @@ import jax.numpy as jnp from datasets import load_dataset # huggingface datasets -from memorax.train_utils import get_residual_memory_models +from memax.train_utils import get_residual_memory_models NUM_EPOCHS = 100 diff --git a/memorax/datasets/mnist_math.py b/memax/datasets/mnist_math.py similarity index 100% rename from memorax/datasets/mnist_math.py rename to memax/datasets/mnist_math.py diff --git a/memorax/datasets/sequential_mnist.py b/memax/datasets/sequential_mnist.py similarity index 100% rename from memorax/datasets/sequential_mnist.py rename to memax/datasets/sequential_mnist.py diff --git a/memax/equinox/__init__.py b/memax/equinox/__init__.py new file mode 100644 index 0000000..485da5c --- /dev/null +++ b/memax/equinox/__init__.py @@ -0,0 +1,8 @@ +"""Equinox backend for memax. + +`memax.equinox.gras` provides the Generalized Recurrent Algebraic Structure (GRAS) base module. +`memax.equinox.groups` provides the algebraic structures (semigroups, groups, etc.) used in GRAS. +`memax.equinox.set_actions` provides set action-based recurrent layers (slow RNNs). +`memax.equinox.semigroups` provides semigroup-based recurrent layers (fast RNNs). +`memax.equinox.scans` provides scan functions for recurrent updates. 
+""" \ No newline at end of file diff --git a/memorax/equinox/gras.py b/memax/equinox/gras.py similarity index 96% rename from memorax/equinox/gras.py rename to memax/equinox/gras.py index 8042090..b352c4b 100644 --- a/memorax/equinox/gras.py +++ b/memax/equinox/gras.py @@ -4,8 +4,8 @@ import jax from jaxtyping import PRNGKeyArray, Shaped -from memorax.equinox.groups import BinaryAlgebra, Module -from memorax.mtypes import Input, OutputEmbedding, RecurrentState, SingleRecurrentState +from memax.equinox.groups import BinaryAlgebra, Module +from memax.mtypes import Input, OutputEmbedding, RecurrentState, SingleRecurrentState class GRAS(Module): diff --git a/memorax/equinox/groups.py b/memax/equinox/groups.py similarity index 97% rename from memorax/equinox/groups.py rename to memax/equinox/groups.py index e891dc0..df0201e 100644 --- a/memorax/equinox/groups.py +++ b/memax/equinox/groups.py @@ -6,8 +6,8 @@ import jax.numpy as jnp from jaxtyping import PRNGKeyArray, Shaped -from memorax.mtypes import Input, RecurrentState, ResetRecurrentState, StartFlag -from memorax.utils import debug_shape +from memax.mtypes import Input, RecurrentState, ResetRecurrentState, StartFlag +from memax.utils import debug_shape class Module(eqx.Module): diff --git a/memorax/equinox/models/__init__.py b/memax/equinox/models/__init__.py similarity index 100% rename from memorax/equinox/models/__init__.py rename to memax/equinox/models/__init__.py diff --git a/memorax/equinox/models/residual.py b/memax/equinox/models/residual.py similarity index 96% rename from memorax/equinox/models/residual.py rename to memax/equinox/models/residual.py index decb806..e89ad74 100644 --- a/memorax/equinox/models/residual.py +++ b/memax/equinox/models/residual.py @@ -4,8 +4,8 @@ from equinox import filter_vmap, nn from jaxtyping import PRNGKeyArray, Shaped -from memorax.equinox.groups import Module -from memorax.mtypes import Input, ResetRecurrentState +from memax.equinox.groups import Module +from memax.mtypes import Input, ResetRecurrentState class ResidualModel(Module): diff --git a/memorax/equinox/scans.py b/memax/equinox/scans.py similarity index 97% rename from memorax/equinox/scans.py rename to memax/equinox/scans.py index e4ea4b5..5bfef08 100644 --- a/memorax/equinox/scans.py +++ b/memax/equinox/scans.py @@ -9,8 +9,8 @@ import jax import jax.numpy as jnp -from memorax.mtypes import RecurrentState -from memorax.utils import debug_shape +from memax.mtypes import RecurrentState +from memax.utils import debug_shape diff --git a/memax/equinox/semigroups/__init__.py b/memax/equinox/semigroups/__init__.py new file mode 100644 index 0000000..b679ed2 --- /dev/null +++ b/memax/equinox/semigroups/__init__.py @@ -0,0 +1,18 @@ +"""This module contains semigroup-based recurrent layers. +They are generally much faster than standard RNNs. + +Each RNN type gets its own file. ++ `memax.equinox.semigroups.attn` provides dot-product attention layer. ++ `memax.equinox.semigroups.delta` provides the DeltaNet layer. ++ `memax.equinox.semigroups.deltap` provides the DeltaProduct layer. ++ `memax.equinox.semigroups.stack` provides framestacking (sliding-window) as an RNN. ++ `memax.equinox.semigroups.fart` provides the Fast AutoRegressive Transformer layer. ++ `memax.equinox.semigroups.ffm` provides the Fast and Forgetful Memory layer. ++ `memax.equinox.semigroups.fwp` provides the Fast Weight Programmer layer. ++ `memax.equinox.semigroups.lru` provides the Linear Recurrent Unit layer. 
++ `memax.equinox.semigroups.gdn` provides the Gated DeltaNet layer. ++ `memax.equinox.semigroups.lrnn` provides a basic linear recurrence. ++ `memax.equinox.semigroups.mlp` provides an MLP (no memory) for completeness. ++ `memax.equinox.semigroups.s6` provides the Selective State Space Model (Mamba) layer. ++ `memax.equinox.semigroups.spherical` provides the Rotational RNN layer (spherical projection). +""" \ No newline at end of file diff --git a/memorax/equinox/semigroups/attn.py b/memax/equinox/semigroups/attn.py similarity index 95% rename from memorax/equinox/semigroups/attn.py rename to memax/equinox/semigroups/attn.py index 3971abb..f5298c1 100644 --- a/memorax/equinox/semigroups/attn.py +++ b/memax/equinox/semigroups/attn.py @@ -6,11 +6,11 @@ from equinox import nn from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped, Bool, Int -from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans import semigroup_scan -from memorax.utils import apply_rope, apply_sinusoidal_pe, combine_and_right_align +from memax.equinox.groups import BinaryAlgebra, Semigroup, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from memax.equinox.scans import semigroup_scan +from memax.utils import apply_rope, apply_sinusoidal_pe, combine_and_right_align AttentionRecurrentState = Tuple[Float[Array, "Window Recurrent"], Float[Array, "Window Recurrent"], Bool[Array, "Window"], Int[Array, "Window"]] AttentionRecurrentStateWithReset = Tuple[AttentionRecurrentState, StartFlag] diff --git a/memorax/equinox/semigroups/delta.py b/memax/equinox/semigroups/delta.py similarity index 95% rename from memorax/equinox/semigroups/delta.py rename to memax/equinox/semigroups/delta.py index 3317a3e..6f2286f 100644 --- a/memorax/equinox/semigroups/delta.py +++ b/memax/equinox/semigroups/delta.py @@ -6,10 +6,10 @@ from equinox import nn from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped -from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans import semigroup_scan +from memax.equinox.groups import BinaryAlgebra, Semigroup, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from memax.equinox.scans import semigroup_scan DeltaFWPRecurrentState = Tuple[ Float[Array, "Key Value"], diff --git a/memorax/equinox/semigroups/deltap.py b/memax/equinox/semigroups/deltap.py similarity index 96% rename from memorax/equinox/semigroups/deltap.py rename to memax/equinox/semigroups/deltap.py index df2855c..cd857fc 100644 --- a/memorax/equinox/semigroups/deltap.py +++ b/memax/equinox/semigroups/deltap.py @@ -6,10 +6,10 @@ from equinox import nn from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped -from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans import semigroup_scan +from memax.equinox.groups import BinaryAlgebra, Semigroup, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from memax.equinox.scans import semigroup_scan DeltaProductRecurrentState = Tuple[ Float[Array, "Key Value"], diff --git a/memorax/equinox/semigroups/dlse.py b/memax/equinox/semigroups/dlse.py similarity index 93% rename from
memorax/equinox/semigroups/dlse.py rename to memax/equinox/semigroups/dlse.py index 30a422b..9ed5b7f 100644 --- a/memorax/equinox/semigroups/dlse.py +++ b/memax/equinox/semigroups/dlse.py @@ -6,10 +6,10 @@ from equinox import filter_vmap, nn from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped -from memorax.equinox.groups import BinaryAlgebra, Module, Semigroup, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans import semigroup_scan +from memax.equinox.groups import BinaryAlgebra, Module, Semigroup, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from memax.equinox.scans import semigroup_scan DLSERecurrentState = Float[Array, "Hidden Hidden"] DLSERecurrentStateWithReset = Tuple[DLSERecurrentState, StartFlag] diff --git a/memorax/equinox/semigroups/fart.py b/memax/equinox/semigroups/fart.py similarity index 94% rename from memorax/equinox/semigroups/fart.py rename to memax/equinox/semigroups/fart.py index 05bbbba..3d0997c 100644 --- a/memorax/equinox/semigroups/fart.py +++ b/memax/equinox/semigroups/fart.py @@ -6,10 +6,10 @@ from equinox import nn from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped -from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans import semigroup_scan +from memax.equinox.groups import BinaryAlgebra, Semigroup, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from memax.equinox.scans import semigroup_scan FARTRecurrentState = Tuple[Float[Array, "Key Value"], Float[Array, "Key"]] FARTRecurrentStateWithReset = Tuple[FARTRecurrentState, StartFlag] diff --git a/memorax/equinox/semigroups/ffm.py b/memax/equinox/semigroups/ffm.py similarity index 96% rename from memorax/equinox/semigroups/ffm.py rename to memax/equinox/semigroups/ffm.py index c614f28..14d2402 100644 --- a/memorax/equinox/semigroups/ffm.py +++ b/memax/equinox/semigroups/ffm.py @@ -7,10 +7,10 @@ from equinox import nn from jaxtyping import Array, Complex, Float, Int, PRNGKeyArray, Real, Shaped, jaxtyped -from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans import semigroup_scan +from memax.equinox.groups import BinaryAlgebra, Semigroup, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from memax.equinox.scans import semigroup_scan FFMRecurrentState = Tuple[Complex[Array, "Trace Context"], Int[Array, ""]] FFMRecurrentStateWithReset = Tuple[FFMRecurrentState, StartFlag] diff --git a/memorax/equinox/semigroups/fwp.py b/memax/equinox/semigroups/fwp.py similarity index 94% rename from memorax/equinox/semigroups/fwp.py rename to memax/equinox/semigroups/fwp.py index 6ad0c1c..d24c8b3 100644 --- a/memorax/equinox/semigroups/fwp.py +++ b/memax/equinox/semigroups/fwp.py @@ -6,10 +6,10 @@ from equinox import nn from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped -from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans import semigroup_scan +from memax.equinox.groups import BinaryAlgebra, Semigroup, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from 
memax.equinox.scans import semigroup_scan FWPRecurrentState = Float[Array, "Key Value"] FWPRecurrentStateWithReset = Tuple[FWPRecurrentState, StartFlag] diff --git a/memorax/equinox/semigroups/gdn.py b/memax/equinox/semigroups/gdn.py similarity index 95% rename from memorax/equinox/semigroups/gdn.py rename to memax/equinox/semigroups/gdn.py index 263aa28..c5b7c89 100644 --- a/memorax/equinox/semigroups/gdn.py +++ b/memax/equinox/semigroups/gdn.py @@ -7,10 +7,10 @@ from equinox import nn from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped -from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans import semigroup_scan +from memax.equinox.groups import BinaryAlgebra, Semigroup, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from memax.equinox.scans import semigroup_scan GDNRecurrentState = Tuple[ Float[Array, "Key Value"], diff --git a/memorax/equinox/semigroups/lrnn.py b/memax/equinox/semigroups/lrnn.py similarity index 93% rename from memorax/equinox/semigroups/lrnn.py rename to memax/equinox/semigroups/lrnn.py index 24a1cf4..597c1b4 100644 --- a/memorax/equinox/semigroups/lrnn.py +++ b/memax/equinox/semigroups/lrnn.py @@ -6,10 +6,10 @@ from equinox import nn from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped -from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans import semigroup_scan +from memax.equinox.groups import BinaryAlgebra, Semigroup, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from memax.equinox.scans import semigroup_scan LinearRNNRecurrentState = Float[Array, "Hidden"] LinearRNNRecurrentStateWithReset = Tuple[LinearRNNRecurrentState, StartFlag] diff --git a/memorax/equinox/semigroups/lru.py b/memax/equinox/semigroups/lru.py similarity index 96% rename from memorax/equinox/semigroups/lru.py rename to memax/equinox/semigroups/lru.py index 9acbbde..786a0b4 100644 --- a/memorax/equinox/semigroups/lru.py +++ b/memax/equinox/semigroups/lru.py @@ -6,10 +6,10 @@ from beartype import beartype as typechecker from jaxtyping import Array, Complex, Float, PRNGKeyArray, Scalar, Shaped, jaxtyped -from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans import semigroup_scan +from memax.equinox.groups import BinaryAlgebra, Semigroup, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from memax.equinox.scans import semigroup_scan LRURecurrentState = Tuple[Complex[Array, "Recurrent"], Complex[Array, "Recurrent"]] LRURecurrentStateWithReset = Tuple[LRURecurrentState, StartFlag] diff --git a/memorax/equinox/semigroups/mlp.py b/memax/equinox/semigroups/mlp.py similarity index 92% rename from memorax/equinox/semigroups/mlp.py rename to memax/equinox/semigroups/mlp.py index b369bd5..e5a8a34 100644 --- a/memorax/equinox/semigroups/mlp.py +++ b/memax/equinox/semigroups/mlp.py @@ -6,10 +6,10 @@ from equinox import nn from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped -from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans 
import semigroup_scan +from memax.equinox.groups import BinaryAlgebra, Semigroup, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from memax.equinox.scans import semigroup_scan MLPRecurrentState = Float[Array, "0"] # Empty array because MLP is not recurrent MLPRecurrentStateWithReset = Tuple[MLPRecurrentState, StartFlag] diff --git a/memorax/equinox/semigroups/nmax.py b/memax/equinox/semigroups/nmax.py similarity index 93% rename from memorax/equinox/semigroups/nmax.py rename to memax/equinox/semigroups/nmax.py index fcd46b0..0c25d32 100644 --- a/memorax/equinox/semigroups/nmax.py +++ b/memax/equinox/semigroups/nmax.py @@ -6,10 +6,10 @@ from equinox import nn from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped -from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans import semigroup_scan +from memax.equinox.groups import BinaryAlgebra, Semigroup, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from memax.equinox.scans import semigroup_scan NMaxRecurrentState = Float[Array, "Hidden"] NMaxRecurrentStateWithReset = Tuple[NMaxRecurrentState, StartFlag] diff --git a/memorax/equinox/semigroups/s6.py b/memax/equinox/semigroups/s6.py similarity index 95% rename from memorax/equinox/semigroups/s6.py rename to memax/equinox/semigroups/s6.py index a8819c1..d208c53 100644 --- a/memorax/equinox/semigroups/s6.py +++ b/memax/equinox/semigroups/s6.py @@ -7,10 +7,10 @@ from beartype import beartype as typechecker from jaxtyping import Array, Complex, Float, PRNGKeyArray, Scalar, Shaped, jaxtyped -from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans import semigroup_scan +from memax.equinox.groups import BinaryAlgebra, Semigroup, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from memax.equinox.scans import semigroup_scan S6RecurrentState = Tuple[Float[Array, "Recurrent"], Float[Array, "Recurrent"]] S6RecurrentStateWithReset = Tuple[S6RecurrentState, StartFlag] diff --git a/memorax/equinox/semigroups/spherical.py b/memax/equinox/semigroups/spherical.py similarity index 94% rename from memorax/equinox/semigroups/spherical.py rename to memax/equinox/semigroups/spherical.py index 9451eba..241e6bd 100644 --- a/memorax/equinox/semigroups/spherical.py +++ b/memax/equinox/semigroups/spherical.py @@ -7,10 +7,10 @@ from equinox import nn from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped -from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans import semigroup_scan +from memax.equinox.groups import BinaryAlgebra, Semigroup, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from memax.equinox.scans import semigroup_scan RotationMatrix = Float[Array, "Hidden Hidden"] SphericalRecurrentState = RotationMatrix diff --git a/memorax/equinox/semigroups/stack.py b/memax/equinox/semigroups/stack.py similarity index 94% rename from memorax/equinox/semigroups/stack.py rename to memax/equinox/semigroups/stack.py index 14b4c7b..ca04965 100644 --- a/memorax/equinox/semigroups/stack.py +++ b/memax/equinox/semigroups/stack.py @@ -6,11 +6,11 @@ from 
equinox import nn from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped, Bool -from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans import semigroup_scan -from memorax.utils import combine_and_right_align +from memax.equinox.groups import BinaryAlgebra, Semigroup, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from memax.equinox.scans import semigroup_scan +from memax.utils import combine_and_right_align StackRecurrentState = Tuple[Float[Array, "Stack Recurrent"], Bool[Array, "Stack"]] StackRecurrentStateWithReset = Tuple[StackRecurrentState, StartFlag] diff --git a/memax/equinox/set_actions/__init__.py b/memax/equinox/set_actions/__init__.py new file mode 100644 index 0000000..c300758 --- /dev/null +++ b/memax/equinox/set_actions/__init__.py @@ -0,0 +1,9 @@ + +"""This module contains set-action-based (classical) recurrent layers. +Each RNN type gets its own file. ++ `memax.equinox.set_actions.elman` provides a basic Elman RNN layer. ++ `memax.equinox.set_actions.lstm` provides the Long Short-Term Memory layer. ++ `memax.equinox.set_actions.gru` provides the Gated Recurrent Unit layer. ++ `memax.equinox.set_actions.mgu` provides the Minimal Gated Unit layer. ++ `memax.equinox.set_actions.spherical` provides a recurrent formulation of the Rotational RNN. See `memax.equinox.semigroups.spherical` for the semigroup version. +""" \ No newline at end of file diff --git a/memorax/equinox/set_actions/elman.py b/memax/equinox/set_actions/elman.py similarity index 93% rename from memorax/equinox/set_actions/elman.py rename to memax/equinox/set_actions/elman.py index bb52a57..c2f2518 100644 --- a/memorax/equinox/set_actions/elman.py +++ b/memax/equinox/set_actions/elman.py @@ -7,10 +7,10 @@ from equinox import nn from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped -from memorax.equinox.groups import BinaryAlgebra, SetAction, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans import set_action_scan +from memax.equinox.groups import BinaryAlgebra, SetAction, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from memax.equinox.scans import set_action_scan ElmanRecurrentState = Float[Array, "Recurrent"] ElmanRecurrentStateWithReset = Tuple[ElmanRecurrentState, StartFlag] diff --git a/memorax/equinox/set_actions/gru.py b/memax/equinox/set_actions/gru.py similarity index 94% rename from memorax/equinox/set_actions/gru.py rename to memax/equinox/set_actions/gru.py index ef5e2a1..5328a19 100644 --- a/memorax/equinox/set_actions/gru.py +++ b/memax/equinox/set_actions/gru.py @@ -6,10 +6,10 @@ from equinox import nn from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped -from memorax.equinox.groups import BinaryAlgebra, SetAction, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans import set_action_scan +from memax.equinox.groups import BinaryAlgebra, SetAction, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from memax.equinox.scans import set_action_scan GRURecurrentState = Float[Array, "Recurrent"] GRURecurrentStateWithReset = Tuple[GRURecurrentState, StartFlag] diff --git a/memorax/equinox/set_actions/lstm.py b/memax/equinox/set_actions/lstm.py similarity index 94%
rename from memorax/equinox/set_actions/lstm.py rename to memax/equinox/set_actions/lstm.py index 3a7348a..6930fb4 100644 --- a/memorax/equinox/set_actions/lstm.py +++ b/memax/equinox/set_actions/lstm.py @@ -6,10 +6,10 @@ from equinox import nn from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped -from memorax.equinox.groups import BinaryAlgebra, SetAction, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, InputEmbedding, StartFlag -from memorax.equinox.scans import set_action_scan +from memax.equinox.groups import BinaryAlgebra, SetAction, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, InputEmbedding, StartFlag +from memax.equinox.scans import set_action_scan LSTMRecurrentState = Tuple[Float[Array, "Recurrent"], Float[Array, "Recurrent"]] LSTMRecurrentStateWithReset = Tuple[LSTMRecurrentState, StartFlag] diff --git a/memorax/equinox/set_actions/mgu.py b/memax/equinox/set_actions/mgu.py similarity index 93% rename from memorax/equinox/set_actions/mgu.py rename to memax/equinox/set_actions/mgu.py index 38e6103..9c2ec9a 100644 --- a/memorax/equinox/set_actions/mgu.py +++ b/memax/equinox/set_actions/mgu.py @@ -6,10 +6,10 @@ from equinox import nn from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped -from memorax.equinox.groups import BinaryAlgebra, SetAction, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans import set_action_scan +from memax.equinox.groups import BinaryAlgebra, SetAction, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from memax.equinox.scans import set_action_scan MGURecurrentState = Float[Array, "Recurrent"] MGURecurrentStateWithReset = Tuple[MGURecurrentState, StartFlag] diff --git a/memorax/equinox/set_actions/spherical.py b/memax/equinox/set_actions/spherical.py similarity index 94% rename from memorax/equinox/set_actions/spherical.py rename to memax/equinox/set_actions/spherical.py index 72ecec4..62fc1be 100644 --- a/memorax/equinox/set_actions/spherical.py +++ b/memax/equinox/set_actions/spherical.py @@ -6,10 +6,10 @@ from equinox import nn from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped -from memorax.equinox.groups import BinaryAlgebra, SetAction, Module, Resettable -from memorax.equinox.gras import GRAS -from memorax.mtypes import Input, StartFlag -from memorax.equinox.scans import set_action_scan +from memax.equinox.groups import BinaryAlgebra, SetAction, Module, Resettable +from memax.equinox.gras import GRAS +from memax.mtypes import Input, StartFlag +from memax.equinox.scans import set_action_scan SphericalRecurrentState = Float[Array, "Recurrent"] SphericalRecurrentStateWithReset = Tuple[SphericalRecurrentState, StartFlag] diff --git a/memorax/equinox/train_utils.py b/memax/equinox/train_utils.py similarity index 90% rename from memorax/equinox/train_utils.py rename to memax/equinox/train_utils.py index cff2bee..06df499 100644 --- a/memorax/equinox/train_utils.py +++ b/memax/equinox/train_utils.py @@ -10,27 +10,27 @@ import optax from jaxtyping import Array, Shaped -from memorax.equinox.groups import Module -from memorax.equinox.set_actions.elman import Elman -from memorax.equinox.set_actions.gru import GRU -from memorax.equinox.set_actions.lstm import LSTM -from memorax.equinox.set_actions.mgu import MGU -from memorax.equinox.set_actions.spherical import Spherical -from memorax.equinox.models.residual import ResidualModel -from 
memorax.equinox.semigroups.fwp import FWP, FWPSemigroup -from memorax.equinox.semigroups.fart import FART, FARTSemigroup -from memorax.equinox.semigroups.ffm import FFM, FFMSemigroup -from memorax.equinox.semigroups.lrnn import LinearRecurrent, LinearRNNSemigroup -from memorax.equinox.semigroups.lru import LRU, LRUSemigroup -from memorax.equinox.semigroups.nmax import NMax, NMaxSemigroup -from memorax.equinox.semigroups.spherical import PSpherical, PSphericalSemigroup -from memorax.equinox.semigroups.s6 import S6, S6Semigroup -from memorax.equinox.semigroups.delta import DeltaNet, DeltaNetSemigroup -from memorax.equinox.semigroups.deltap import DeltaProduct, DeltaProductSemigroup -from memorax.equinox.semigroups.gdn import GDN, GDNSemigroup -from memorax.equinox.semigroups.mlp import MLP -from memorax.equinox.semigroups.stack import Stack, StackSemigroup -from memorax.equinox.semigroups.attn import Attention, AttentionSemigroup +from memax.equinox.groups import Module +from memax.equinox.set_actions.elman import Elman +from memax.equinox.set_actions.gru import GRU +from memax.equinox.set_actions.lstm import LSTM +from memax.equinox.set_actions.mgu import MGU +from memax.equinox.set_actions.spherical import Spherical +from memax.equinox.models.residual import ResidualModel +from memax.equinox.semigroups.fwp import FWP, FWPSemigroup +from memax.equinox.semigroups.fart import FART, FARTSemigroup +from memax.equinox.semigroups.ffm import FFM, FFMSemigroup +from memax.equinox.semigroups.lrnn import LinearRecurrent, LinearRNNSemigroup +from memax.equinox.semigroups.lru import LRU, LRUSemigroup +from memax.equinox.semigroups.nmax import NMax, NMaxSemigroup +from memax.equinox.semigroups.spherical import PSpherical, PSphericalSemigroup +from memax.equinox.semigroups.s6 import S6, S6Semigroup +from memax.equinox.semigroups.delta import DeltaNet, DeltaNetSemigroup +from memax.equinox.semigroups.deltap import DeltaProduct, DeltaProductSemigroup +from memax.equinox.semigroups.gdn import GDN, GDNSemigroup +from memax.equinox.semigroups.mlp import MLP +from memax.equinox.semigroups.stack import Stack, StackSemigroup +from memax.equinox.semigroups.attn import Attention, AttentionSemigroup def add_batch_dim(h, batch_size: int, axis: int = 0) -> Shaped[Array, "Batch ..."]: @@ -116,7 +116,7 @@ def loss_classify_terminal_output( return the cross entropy loss between the true yn and predicted y1n. Args: - model: memorax.groups.Module + model: memax.groups.Module x: (batch, time, in_feature) y: (batch, num_classes) diff --git a/memax/linen/__init__.py b/memax/linen/__init__.py new file mode 100644 index 0000000..5341b16 --- /dev/null +++ b/memax/linen/__init__.py @@ -0,0 +1,8 @@ +"""Flax Linen backend for memax. + +`memax.linen.gras` provides the Generalized Recurrent Algebraic Structure (GRAS) base module. +`memax.linen.groups` provides the algebraic structures (semigroups, groups, etc.) used in GRAS. +`memax.linen.set_actions` provides set action-based recurrent layers (slow RNNs). +`memax.linen.semigroups` provides semigroup-based recurrent layers (fast RNNs). +`memax.linen.scans` provides scan functions for recurrent updates. 
+""" \ No newline at end of file diff --git a/memorax/linen/gras.py b/memax/linen/gras.py similarity index 96% rename from memorax/linen/gras.py rename to memax/linen/gras.py index d4ef1a2..d62c4a2 100644 --- a/memorax/linen/gras.py +++ b/memax/linen/gras.py @@ -4,8 +4,8 @@ import jax from jaxtyping import PRNGKeyArray, Shaped -from memorax.mtypes import Input, OutputEmbedding, RecurrentState, SingleRecurrentState -from memorax.linen.groups import BinaryAlgebra, Module +from memax.mtypes import Input, OutputEmbedding, RecurrentState, SingleRecurrentState +from memax.linen.groups import BinaryAlgebra, Module class GRAS(Module): diff --git a/memorax/linen/groups.py b/memax/linen/groups.py similarity index 98% rename from memorax/linen/groups.py rename to memax/linen/groups.py index 2a11fea..e0976a5 100644 --- a/memorax/linen/groups.py +++ b/memax/linen/groups.py @@ -6,8 +6,8 @@ import jax.numpy as jnp from jaxtyping import PRNGKeyArray, Shaped -from memorax.mtypes import Input, RecurrentState, ResetRecurrentState, StartFlag -from memorax.utils import debug_shape +from memax.mtypes import Input, RecurrentState, ResetRecurrentState, StartFlag +from memax.utils import debug_shape class Module(nn.Module): diff --git a/memorax/linen/models/__init__.py b/memax/linen/models/__init__.py similarity index 100% rename from memorax/linen/models/__init__.py rename to memax/linen/models/__init__.py diff --git a/memorax/linen/models/residual.py b/memax/linen/models/residual.py similarity index 96% rename from memorax/linen/models/residual.py rename to memax/linen/models/residual.py index caf7edf..d00ee74 100644 --- a/memorax/linen/models/residual.py +++ b/memax/linen/models/residual.py @@ -4,8 +4,8 @@ import jax from jaxtyping import PRNGKeyArray, Shaped -from memorax.mtypes import Input, ResetRecurrentState -from memorax.linen.groups import Module +from memax.mtypes import Input, ResetRecurrentState +from memax.linen.groups import Module class ResidualModel(Module): diff --git a/memorax/linen/scans.py b/memax/linen/scans.py similarity index 97% rename from memorax/linen/scans.py rename to memax/linen/scans.py index b8ad680..d2b2e6d 100644 --- a/memorax/linen/scans.py +++ b/memax/linen/scans.py @@ -11,8 +11,8 @@ import jax import jax.numpy as jnp -from memorax.mtypes import RecurrentState -from memorax.utils import debug_shape +from memax.mtypes import RecurrentState +from memax.utils import debug_shape def set_action_scan( diff --git a/memorax/linen/semigroups/__init__.py b/memax/linen/semigroups/__init__.py similarity index 100% rename from memorax/linen/semigroups/__init__.py rename to memax/linen/semigroups/__init__.py diff --git a/memorax/linen/semigroups/fart.py b/memax/linen/semigroups/fart.py similarity index 94% rename from memorax/linen/semigroups/fart.py rename to memax/linen/semigroups/fart.py index 4685f9c..2b5448b 100644 --- a/memorax/linen/semigroups/fart.py +++ b/memax/linen/semigroups/fart.py @@ -6,10 +6,10 @@ from beartype import beartype as typechecker from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped -from memorax.mtypes import Input, StartFlag -from memorax.linen.groups import BinaryAlgebra, Semigroup, Resettable -from memorax.linen.gras import GRAS -from memorax.linen.scans import semigroup_scan +from memax.mtypes import Input, StartFlag +from memax.linen.groups import BinaryAlgebra, Semigroup, Resettable +from memax.linen.gras import GRAS +from memax.linen.scans import semigroup_scan FARTRecurrentState = Tuple[Float[Array, "Key Value"], Float[Array, "Key"]] 
FARTRecurrentStateWithReset = Tuple[FARTRecurrentState, StartFlag] diff --git a/memorax/linen/semigroups/lru.py b/memax/linen/semigroups/lru.py similarity index 97% rename from memorax/linen/semigroups/lru.py rename to memax/linen/semigroups/lru.py index 1e5477e..2de8a09 100644 --- a/memorax/linen/semigroups/lru.py +++ b/memax/linen/semigroups/lru.py @@ -8,10 +8,10 @@ from beartype import beartype as typechecker from jaxtyping import Array, Complex, Float, PRNGKeyArray, Scalar, Shaped, jaxtyped -from memorax.mtypes import Input, StartFlag -from memorax.linen.groups import Semigroup, Resettable -from memorax.linen.gras import GRAS -from memorax.linen.scans import semigroup_scan +from memax.mtypes import Input, StartFlag +from memax.linen.groups import Semigroup, Resettable +from memax.linen.gras import GRAS +from memax.linen.scans import semigroup_scan LRURecurrentState = Tuple[Complex[Array, "Recurrent"], Complex[Array, "Recurrent"]] LRURecurrentStateWithReset = Tuple[LRURecurrentState, StartFlag] diff --git a/memorax/linen/semigroups/s6.py b/memax/linen/semigroups/s6.py similarity index 95% rename from memorax/linen/semigroups/s6.py rename to memax/linen/semigroups/s6.py index a670736..cac828e 100644 --- a/memorax/linen/semigroups/s6.py +++ b/memax/linen/semigroups/s6.py @@ -8,10 +8,10 @@ from beartype import beartype as typechecker from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped -from memorax.mtypes import Input, StartFlag -from memorax.linen.groups import Semigroup, Resettable -from memorax.linen.gras import GRAS -from memorax.linen.scans import semigroup_scan +from memax.mtypes import Input, StartFlag +from memax.linen.groups import Semigroup, Resettable +from memax.linen.gras import GRAS +from memax.linen.scans import semigroup_scan S6RecurrentState = Tuple[Float[Array, "Recurrent"], Float[Array, "Recurrent"]] S6RecurrentStateWithReset = Tuple[S6RecurrentState, StartFlag] diff --git a/memorax/linen/set_actions/__init__.py b/memax/linen/set_actions/__init__.py similarity index 58% rename from memorax/linen/set_actions/__init__.py rename to memax/linen/set_actions/__init__.py index 7f38d9f..a610395 100644 --- a/memorax/linen/set_actions/__init__.py +++ b/memax/linen/set_actions/__init__.py @@ -1,4 +1,4 @@ """This module contains set-action-based (classical) recurrent layers. Each RNN type gets its own file. -+ `memorax.linen.set_actions.gru` provides the Gated Recurrent Unit layer. ++ `memax.linen.set_actions.gru` provides the Gated Recurrent Unit layer. 
""" \ No newline at end of file diff --git a/memorax/linen/set_actions/gru.py b/memax/linen/set_actions/gru.py similarity index 93% rename from memorax/linen/set_actions/gru.py rename to memax/linen/set_actions/gru.py index 11a85f8..91e2c99 100644 --- a/memorax/linen/set_actions/gru.py +++ b/memax/linen/set_actions/gru.py @@ -6,10 +6,10 @@ from beartype import beartype as typechecker from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped -from memorax.mtypes import Input, StartFlag -from memorax.linen.groups import BinaryAlgebra, SetAction, Resettable -from memorax.linen.gras import GRAS -from memorax.linen.scans import set_action_scan +from memax.mtypes import Input, StartFlag +from memax.linen.groups import BinaryAlgebra, SetAction, Resettable +from memax.linen.gras import GRAS +from memax.linen.scans import set_action_scan GRURecurrentState = Float[Array, "Recurrent"] GRURecurrentStateWithReset = Tuple[GRURecurrentState, StartFlag] diff --git a/memorax/linen/train_utils.py b/memax/linen/train_utils.py similarity index 94% rename from memorax/linen/train_utils.py rename to memax/linen/train_utils.py index 0ddb93a..32f18de 100644 --- a/memorax/linen/train_utils.py +++ b/memax/linen/train_utils.py @@ -11,11 +11,11 @@ from flax.core import FrozenDict from jaxtyping import Array, Shaped -from memorax.linen.set_actions.gru import GRU -from memorax.linen.models.residual import ResidualModel -from memorax.linen.semigroups.fart import FARTSemigroup, FART -from memorax.linen.semigroups.lru import LRUSemigroup, LRU -from memorax.linen.semigroups.s6 import S6Semigroup, S6 +from memax.linen.set_actions.gru import GRU +from memax.linen.models.residual import ResidualModel +from memax.linen.semigroups.fart import FARTSemigroup, FART +from memax.linen.semigroups.lru import LRUSemigroup, LRU +from memax.linen.semigroups.s6 import S6Semigroup, S6 def add_batch_dim(h, batch_size: int, axis: int = 0) -> Shaped[Array, "Batch ..."]: diff --git a/memorax/mtypes.py b/memax/mtypes.py similarity index 100% rename from memorax/mtypes.py rename to memax/mtypes.py diff --git a/memorax/utils.py b/memax/utils.py similarity index 99% rename from memorax/utils.py rename to memax/utils.py index 954004f..1f21300 100644 --- a/memorax/utils.py +++ b/memax/utils.py @@ -1,5 +1,5 @@ """ -This module contains framework-agnostic utility functions used throughout the Memorax library. +This module contains framework-agnostic utility functions used throughout the memax library. """ from typing import Tuple diff --git a/memorax/equinox/__init__.py b/memorax/equinox/__init__.py deleted file mode 100644 index d2de016..0000000 --- a/memorax/equinox/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Equinox backend for Memorax. - -`memorax.equinox.gras` provides the Generalized Recurrent Algebraic Structure (GRAS) base module. -`memorax.equinox.groups` provides the algebraic structures (semigroups, groups, etc.) used in GRAS. -`memorax.equinox.set_actions` provides set action-based recurrent layers (slow RNNs). -`memorax.equinox.semigroups` provides semigroup-based recurrent layers (fast RNNs). -`memorax.equinox.scans` provides scan functions for recurrent updates. -""" \ No newline at end of file diff --git a/memorax/equinox/semigroups/__init__.py b/memorax/equinox/semigroups/__init__.py deleted file mode 100644 index b33c23a..0000000 --- a/memorax/equinox/semigroups/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -"""This module contains semigroup-based recurrent layers. -They are generally much faster than standard RNNs. 
-
-Each RNN type gets its own file.
-+ `memorax.equinox.semigroups.attn` provides dot-product attention layer.
-+ `memorax.equinox.semigroups.delta` provides the DeltaNet layer.
-+ `memorax.equinox.semigroups.deltap` provides the DeltaProduct layer.
-+ `memorax.equinox.semigroups.stack` provides framestacking (sliding-window) as an RNN.
-+ `memorax.equinox.semigroups.fart` provides the Fast AutoRegressive Transformer layer.
-+ `memorax.equinox.semigroups.ffm` provides the Fast and Forgetful Memory layer.
-+ `memorax.equinox.semigroups.fwp` provides the Fast Weight Programmer layer.
-+ `memorax.equinox.semigroups.lru` provides the Linear Recurrent Unit layer.
-+ `memorax.equinox.semigroups.gdn` provides the Gated DeltaNet layer.
-+ `memorax.equinox.semigroups.lrnn` provides a basic linear recurrence.
-+ `memorax.equinox.semigroups.mlp` provides an MLP (no memory) for completeness.
-+ `memorax.equinox.semigroups.s6` provides the Selective State Space Model (Mamba) layer.
-+ `memorax.equinox.semigroups.spherical` provides Rotational RNN layer (spherical projection).
-"""
\ No newline at end of file
diff --git a/memorax/equinox/set_actions/__init__.py b/memorax/equinox/set_actions/__init__.py
deleted file mode 100644
index 94d6e58..0000000
--- a/memorax/equinox/set_actions/__init__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-
-"""This module contains set-action-based (classical) recurrent layers.
-Each RNN type gets its own file.
-+ `memorax.equinox.set_actions.elman` provides a basic Elman RNN layer.
-+ `memorax.equinox.set_actions.lstm` provides the Long Short-Term Memory layer.
-+ `memorax.equinox.set_actions.gru` provides the Gated Recurrent Unit layer.
-+ `memorax.equinox.set_actions.mru` provides the Minimal Gated Unit layer
-+ `memorax.equinox.set_actions.spherical` provides a recurrent formulation of the Rotational RNN. See `memorax.equinox.semigroups.spherical` for the semigroup version.
-"""
\ No newline at end of file
diff --git a/memorax/linen/__init__.py b/memorax/linen/__init__.py
deleted file mode 100644
index da01f63..0000000
--- a/memorax/linen/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-"""Flax Linen backend for Memorax.
-
-`memorax.linen.gras` provides the Generalized Recurrent Algebraic Structure (GRAS) base module.
-`memorax.linen.groups` provides the algebraic structures (semigroups, groups, etc.) used in GRAS.
-`memorax.linen.set_actions` provides set action-based recurrent layers (slow RNNs).
-`memorax.linen.semigroups` provides semigroup-based recurrent layers (fast RNNs).
-`memorax.linen.scans` provides scan functions for recurrent updates.
-"""
\ No newline at end of file
diff --git a/run_equinox_experiments.py b/run_equinox_experiments.py
index e976802..3d43d5c 100644
--- a/run_equinox_experiments.py
+++ b/run_equinox_experiments.py
@@ -1,6 +1,6 @@
 """This script runs experiments training various recurrent
 memory models on different datasets using Equinox.
 It serves as a reference implementation
-for training and evaluating memorax modules."""
+for training and evaluating memax modules."""
 
 import argparse
@@ -11,10 +11,10 @@
 import tqdm
 import wandb
 
-from memorax.datasets.mnist_math import get_dataset as get_mnist_math
-from memorax.datasets.sequential_mnist import get_dataset as get_sequential_mnist
-from memorax.datasets.continuous_localization import get_rot_dataset, get_trans_dataset
-from memorax.equinox.train_utils import (
+from memax.datasets.mnist_math import get_dataset as get_mnist_math
+from memax.datasets.sequential_mnist import get_dataset as get_sequential_mnist
+from memax.datasets.continuous_localization import get_rot_dataset, get_trans_dataset
+from memax.equinox.train_utils import (
     get_residual_memory_models,
     loss_classify_terminal_output,
     loss_regress_terminal_output,
@@ -41,7 +41,7 @@ def parse_args():
     parser.add_argument(
         "--project-name",
         type=str,
-        default="memorax-debug",
+        default="memax-debug",
         help="Weights & Biases project name",
     )
     parser.add_argument(
diff --git a/run_linen_experiments.py b/run_linen_experiments.py
index 74be348..edf8ab6 100644
--- a/run_linen_experiments.py
+++ b/run_linen_experiments.py
@@ -1,6 +1,6 @@
 """This script runs experiments training various recurrent
 memory models on different datasets using Flax Linen.
 It serves as a reference implementation
-for training and evaluating memorax modules."""
+for training and evaluating memax modules."""
 
 import argparse
@@ -11,9 +11,9 @@
 import tqdm
 import wandb
 
-from memorax.datasets.mnist_math import get_dataset as get_mnist_math
-from memorax.datasets.sequential_mnist import get_dataset as get_sequential_mnist
-from memorax.linen.train_utils import (
+from memax.datasets.mnist_math import get_dataset as get_mnist_math
+from memax.datasets.sequential_mnist import get_dataset as get_sequential_mnist
+from memax.linen.train_utils import (
     get_residual_memory_models,
     loss_classify_terminal_output,
     update_model,
@@ -39,7 +39,7 @@ def parse_args():
     parser.add_argument(
         "--project-name",
        type=str,
-        default="memorax-debug",
+        default="memax-debug",
         help="Weights & Biases project name",
     )
     parser.add_argument(
diff --git a/setup.py b/setup.py
index 9154857..484b6c4 100644
--- a/setup.py
+++ b/setup.py
@@ -1,14 +1,14 @@
 from setuptools import find_packages, setup
 
 setup(
-    name="memorax",
+    name="memax",
     version="0.1.0",
     author="Steven Morad",
     author_email="stevenmorad@gmail.com",
     description="Deep memory and sequence modeling in JAX",
     long_description=open("README.md").read(),
     long_description_content_type="text/markdown",
-    url="https://github.com/smorad/memorax",
+    url="https://github.com/smorad/memax",
     packages=find_packages(),
     install_requires=[
         "jax",
diff --git a/tests/test_associative_equinox.py b/tests/test_associative_equinox.py
index e016b35..962cac1 100644
--- a/tests/test_associative_equinox.py
+++ b/tests/test_associative_equinox.py
@@ -5,8 +5,8 @@
 import jax
 import jax.numpy as jnp
 
-from memorax.equinox.groups import Semigroup
-from memorax.equinox.train_utils import get_semigroups
+from memax.equinox.groups import Semigroup
+from memax.equinox.train_utils import get_semigroups
 
 
 def random_state(state, key):
diff --git a/tests/test_associative_linen.py b/tests/test_associative_linen.py
index 92523f9..1d0fb9d 100644
--- a/tests/test_associative_linen.py
+++ b/tests/test_associative_linen.py
@@ -5,8 +5,8 @@
 import jax
 import jax.numpy as jnp
 
-from memorax.linen.groups import Semigroup
-from memorax.linen.train_utils import get_semigroups
+from memax.linen.groups import Semigroup
+from memax.linen.train_utils import get_semigroups
 
 
 def random_state(state, key):
diff --git a/tests/test_continuous_localization.py b/tests/test_continuous_localization.py
index e5d9639..cfdf739 100644
--- a/tests/test_continuous_localization.py
+++ b/tests/test_continuous_localization.py
@@ -1,4 +1,4 @@
-from memorax.datasets.continuous_localization import step
+from memax.datasets.continuous_localization import step
 import jax.numpy as jnp
 from jax.scipy.spatial.transform import Rotation
 import jax
diff --git a/tests/test_initial_input_equinox.py b/tests/test_initial_input_equinox.py
index 61824a4..e8d4e19 100644
--- a/tests/test_initial_input_equinox.py
+++ b/tests/test_initial_input_equinox.py
@@ -5,7 +5,7 @@
 import jax.numpy as jnp
 import optax
 
-from memorax.equinox.train_utils import get_residual_memory_models
+from memax.equinox.train_utils import get_residual_memory_models
 
 
 def get_desired_accuracies():
diff --git a/tests/test_initial_input_linen.py b/tests/test_initial_input_linen.py
index 0a8ed0b..7e00390 100644
--- a/tests/test_initial_input_linen.py
+++ b/tests/test_initial_input_linen.py
@@ -5,7 +5,7 @@
 import optax
 from functools import partial
 
-from memorax.linen.train_utils import get_residual_memory_models
+from memax.linen.train_utils import get_residual_memory_models
 
 
 def get_desired_accuracies():
diff --git a/tests/test_readme.py b/tests/test_readme.py
index fd90407..b3ba6f5 100644
--- a/tests/test_readme.py
+++ b/tests/test_readme.py
@@ -3,10 +3,10 @@ def test_readme():
     import jax
     import jax.numpy as jnp
 
-    from memorax.equinox.set_actions.gru import GRU
-    from memorax.equinox.models.residual import ResidualModel
-    from memorax.equinox.semigroups.lru import LRU, LRUSemigroup
-    from memorax.utils import debug_shape
+    from memax.equinox.set_actions.gru import GRU
+    from memax.equinox.models.residual import ResidualModel
+    from memax.equinox.semigroups.lru import LRU, LRUSemigroup
+    from memax.utils import debug_shape
 
     # You can pack multiple subsequences into a single sequence using the start flag
     sequence_starts = jnp.array([True, False, False, True, False])
@@ -73,11 +73,11 @@
     h, y = eqx.filter_jit(model)(latest_h, inputs)
 
 
 def test_readme_quickstart():
-    from memorax.equinox.train_utils import get_residual_memory_models
+    from memax.equinox.train_utils import get_residual_memory_models
     import jax
     import jax.numpy as jnp
     from equinox import filter_jit, filter_vmap
-    from memorax.equinox.train_utils import add_batch_dim
+    from memax.equinox.train_utils import add_batch_dim
 
     T, F = 5, 6  # time and feature dim
diff --git a/tests/test_reset_equinox.py b/tests/test_reset_equinox.py
index 47af474..9dabec6 100644
--- a/tests/test_reset_equinox.py
+++ b/tests/test_reset_equinox.py
@@ -4,7 +4,7 @@
 import jax.numpy as jnp
 import pytest
 
-from memorax.equinox.train_utils import add_batch_dim, get_residual_memory_models
+from memax.equinox.train_utils import add_batch_dim, get_residual_memory_models
 
 @pytest.mark.parametrize("name, model", get_residual_memory_models(
     input=1, hidden=8, output=10, num_layers=2, key=jax.random.key(0)
diff --git a/tests/test_reset_linen.py b/tests/test_reset_linen.py
index fbffc0e..68d10f5 100644
--- a/tests/test_reset_linen.py
+++ b/tests/test_reset_linen.py
@@ -4,7 +4,7 @@
 import jax
 import jax.numpy as jnp
 
-from memorax.linen.train_utils import add_batch_dim, get_residual_memory_models
+from memax.linen.train_utils import add_batch_dim, get_residual_memory_models
 
 @pytest.mark.parametrize("name, model", get_residual_memory_models(
     hidden=8, output=10, num_layers=2
diff --git a/tests/test_stack_equinox.py b/tests/test_stack_equinox.py
index 7a4ede8..40ace0f 100644
--- a/tests/test_stack_equinox.py
+++ b/tests/test_stack_equinox.py
@@ -4,7 +4,7 @@
 import jax.numpy as jnp
 import pytest
 
-from memorax.equinox.semigroups.stack import Stack
+from memax.equinox.semigroups.stack import Stack
 
 
 def test_stack():

From 699668a31901f21e2a027c278db3ac22d6fb0486 Mon Sep 17 00:00:00 2001
From: Steven Morad
Date: Wed, 3 Dec 2025 18:27:27 +0800
Subject: [PATCH 2/2] Factorize install deps

---
 README.md |  9 +++++++--
 setup.py  | 24 ++++++++++++++++++------
 2 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 6aace0a..d5b8e6a 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
-# memax - Sequence and Memory Modeling in JAX
+# Memax - Sequence and Memory Modeling in JAX
 
 [![Tests](https://github.com/smorad/memax/actions/workflows/python_app.yaml/badge.svg)](https://github.com/smorad/memax/actions/workflows/python_app.yaml)
 
-memax is a library for efficient recurrent models. Using category theory, we utilize a [simple interface](memax/equinox/groups.py) that should work for nearly all recurrent models. We provide a unified interface for fast recurrent state resets across the sequence, allowing you to train over batches of variable-length sequences without sequence truncation or zero-padding.
+Memax is a library for efficient recurrent models. Using category theory, we utilize a [simple interface](memax/equinox/groups.py) that should work for nearly all recurrent models. We provide a unified interface for fast recurrent state resets across the sequence, allowing you to train over batches of variable-length sequences without sequence truncation or zero-padding.
 
 ## Table of Contents
 1. [Models](#recurrent-models)
@@ -58,6 +58,11 @@ Install `memax` using pip and git for your specific framework
 pip install "memax[equinox]@git+https://github.com/smorad/memax"
 pip install "memax[flax]@git+https://github.com/smorad/memax"
 ```
+If you want to use our dataset and training scripts, install with the `train` extra:
+```bash
+pip install "memax[train,equinox]@git+https://github.com/smorad/memax"
+pip install "memax[train,flax]@git+https://github.com/smorad/memax"
+```
 
 ## Equinox Quickstart
 ```python
diff --git a/setup.py b/setup.py
index 484b6c4..a64387a 100644
--- a/setup.py
+++ b/setup.py
@@ -16,17 +16,29 @@
         "jaxtyping",
         "optax",
         "beartype",
-        "tqdm",
-        "datasets",
-        "pillow",
-        "wandb",
     ],
     extras_require={
         'equinox': ['equinox'],
-        'flax': ['flax'],
+        # TODO: drop the uninstallable marker package below once flax supports Python 3.14
+        'flax': [
+            'flax',
+            'please-downgrade-to-python-3.13-for-flax; python_version >= "3.14"',
+        ],
+        'train': [
+            'datasets',
+            'tqdm',
+            'pillow',
+            'wandb',
+        ],
         'all': [
             'equinox',
-            'flax'
+            'flax',
+            'please-downgrade-to-python-3.13-for-flax; python_version >= "3.14"',
+            # train deps, kept in sync with the 'train' extra above
+            'datasets',
+            'tqdm',
+            'pillow',
+            'wandb',
         ],
         'test': [
             'pytest',