1 change: 0 additions & 1 deletion README.md
@@ -18,7 +18,6 @@ We implement both linear and log-complexity recurrent models.
|------|--------------------------|-------|------|
| Linear Recurrent Unit | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2303.06349) | [[code]](memorax/equinox/semigroups/lru.py) |
| Selective State Space Model (S6) | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2312.00752) | [[code]](memorax/equinox/semigroups/s6.py) |
| Diagonal Selective State Space Model (S6D) | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2312.00752) | [[code]](memorax/equinox/semigroups/s6d.py) |
| Linear Recurrent Neural Network | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/1709.04057) | [[code]](memorax/equinox/semigroups/lrnn.py) |
| Fast Autoregressive Transformer | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2006.16236) | [[code]](memorax/equinox/semigroups/fart.py) |
| Fast and Forgetful Memory | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2310.04128) | [[code]](memorax/equinox/semigroups/ffm.py) |
1 change: 0 additions & 1 deletion memorax/equinox/semigroups/__init__.py
@@ -14,6 +14,5 @@
+ `memorax.equinox.semigroups.lrnn` provides a basic linear recurrence.
+ `memorax.equinox.semigroups.mlp` provides an MLP (no memory) for completeness.
+ `memorax.equinox.semigroups.s6` provides the Selective State Space Model (Mamba) layer.
+ `memorax.equinox.semigroups.s6d` provides a diagonal variant of the Selective State Space Model (Mamba) layer.
+ `memorax.equinox.semigroups.spherical` provides a Rotational RNN layer (spherical projection).
"""
27 changes: 9 additions & 18 deletions memorax/equinox/semigroups/s6.py
@@ -16,18 +16,10 @@
S6RecurrentStateWithReset = Tuple[S6RecurrentState, StartFlag]


# Inits
@jaxtyped(typechecker=typechecker)
def glorot_init(
key: PRNGKeyArray, shape: Tuple[int, ...], normalization: Scalar = jnp.array(1.0)
):
return jax.random.normal(key=key, shape=shape) / normalization


class S6Semigroup(Semigroup):
"""The full-rank S6 semigroup (recurrent update) from https://arxiv.org/abs/2312.00752.
"""The S6 semigroup (recurrent update) from https://arxiv.org/abs/2312.00752.

This is a S5/LRU recurrent update with a learnable timestep parameter. """
This is an S5/LRU recurrent update with a learnable timestep parameter. """

recurrent_size: int
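The S6Semigroup docstring above describes the update as an S5/LRU-style recurrence with a learnable timestep. As a reference point for readers skimming the diff, here is a minimal, self-contained sketch (not code from this repo) of the combine rule such a diagonal linear recurrence reduces to under a parallel scan, checked against a sequential loop; the shapes and the direct use of `jax.lax.associative_scan` are illustrative assumptions, not the library's actual scan.

```python
# Sketch: each scan element is an (A_bar, Bu) pair with h_t = A_bar_t * h_{t-1} + Bu_t.
# Composing step i followed by step j gives (A_i * A_j, A_j * b_i + b_j),
# which is associative, so the recurrence can run as a parallel prefix scan.
import jax
import jax.numpy as jnp

def combine(left, right):
    A_i, b_i = left
    A_j, b_j = right
    return A_i * A_j, A_j * b_i + b_j

T, R = 6, 4  # sequence length, recurrent size
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
A_bar = jax.random.uniform(k1, (T, R))   # per-step diagonal decay in (0, 1)
Bu = jax.random.normal(k2, (T, R))       # per-step input contribution

# Parallel prefix scan over time: O(log T) depth instead of O(T).
_, h_scan = jax.lax.associative_scan(combine, (A_bar, Bu), axis=0)

# Sequential reference recurrence, starting from h_0 = 0.
h = jnp.zeros(R)
h_seq = []
for t in range(T):
    h = A_bar[t] * h + Bu[t]
    h_seq.append(h)

assert jnp.allclose(h_scan, jnp.stack(h_seq), atol=1e-5)
```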

@@ -58,8 +50,7 @@ def __call__(

class S6(GRAS):
"""
The full-rank S6 SSM, an SSM with a trainable dt.
The recurrent matrix A is diagonal, but the B, C matrices are full-rank.
The S6 SSM with diagonal recurrence. We base this on the LRU, and add a trainable dt.

You might want to use this as a building block for a more complex model.
"""
@@ -93,8 +84,8 @@ def __init__(self, recurrent_size, hidden_size, key):
self.scan = semigroup_scan

self.A_log = jax.random.normal(keys[0], (self.recurrent_size,))
self.B = nn.Linear(self.hidden_size, self.recurrent_size * self.recurrent_size, key=keys[1], use_bias=False)
self.C = nn.Linear(self.recurrent_size, self.hidden_size, key=keys[2], use_bias=False)
self.B = nn.Linear(self.hidden_size, self.recurrent_size, key=keys[1])
self.C = nn.Linear(self.recurrent_size, self.hidden_size, key=keys[2])
self.dt = nn.Sequential([
nn.Linear(self.hidden_size, self.recurrent_size, key=keys[3]),
nn.Lambda(jax.nn.softplus)
@@ -106,12 +97,12 @@ def forward_map(self, x: Input, key: Optional[Shaped[PRNGKeyArray, ""]] = None):
dt = self.dt(emb)
A = -jnp.exp(self.A_log)
A_bar = jnp.exp(dt * A)
B = self.B(emb).reshape(self.recurrent_size, self.recurrent_size)
# NOTE: A is diagonal so we can compute B_bar more simply than the mamba paper
B = self.B(emb)
# NOTE: A and B are diagonal so we can compute B_bar more simply than the mamba paper
# Thankfully, inv(A) is just 1 / A if A is diagonal
# Furthermore the dt's cancel: 1 / (dt A) with dt B
B_bar = jnp.diag(1 / A * (A_bar - 1.0)) @ B
Bu = B_bar @ emb
B_bar = 1 / A * (A_bar - 1.0) * B
Bu = B_bar * emb
return (A_bar, Bu), start

@jaxtyped(typechecker=typechecker)
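The updated forward_map comments argue that, for a diagonal A and a per-channel B, the zero-order-hold discretization B_bar = (dt A)^{-1} (exp(dt A) - I) dt B collapses to an elementwise expression: inv(A) is just 1/A and the dt factors cancel. A small standalone check of that claim (not part of the diff; it assumes B is stored as one entry per channel, so `jnp.diag(B)` is its matrix form):

```python
# Standalone sanity check: for diagonal A and per-channel B, the matrix-valued
# ZOH discretization reduces to the elementwise form used in forward_map.
import jax
import jax.numpy as jnp

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
recurrent_size = 8

A = -jnp.exp(jax.random.normal(k1, (recurrent_size,)))          # diagonal of A, negative real
B = jax.random.normal(k2, (recurrent_size,))                    # one B entry per channel
dt = jax.nn.softplus(jax.random.normal(k3, (recurrent_size,)))  # learnable timestep

A_bar = jnp.exp(dt * A)

# Matrix form, written out explicitly with diagonal matrices:
#   B_bar = inv(dt A) (exp(dt A) - I) dt B
B_bar_matrix = (
    jnp.linalg.inv(jnp.diag(dt * A))
    @ (jnp.diag(A_bar) - jnp.eye(recurrent_size))
    @ jnp.diag(dt * B)
)

# Elementwise form from the updated forward_map: the dt's cancel and inv(A) = 1/A.
B_bar_elementwise = 1 / A * (A_bar - 1.0) * B

assert jnp.allclose(jnp.diag(B_bar_matrix), B_bar_elementwise, atol=1e-5)
```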
131 changes: 0 additions & 131 deletions memorax/equinox/semigroups/s6d.py

This file was deleted.

56 changes: 29 additions & 27 deletions memorax/equinox/train_utils.py
@@ -2,7 +2,7 @@
It includes loss functions, accuracy metrics, and training loops.
It also provides a straightforward way to construct multi-layer recurrent models."""

from beartype.typing import Callable, Dict, Tuple
from beartype.typing import Callable, Dict, Tuple, Any, Optional

import equinox as eqx
import jax
@@ -25,7 +25,6 @@
from memorax.equinox.semigroups.nmax import NMax, NMaxSemigroup
from memorax.equinox.semigroups.spherical import PSpherical, PSphericalSemigroup
from memorax.equinox.semigroups.s6 import S6, S6Semigroup
from memorax.equinox.semigroups.s6d import S6D, S6DSemigroup
from memorax.equinox.semigroups.delta import DeltaNet, DeltaNetSemigroup
from memorax.equinox.semigroups.deltap import DeltaProduct, DeltaProductSemigroup
from memorax.equinox.semigroups.gdn import GDN, GDNSemigroup
@@ -226,7 +225,6 @@ def get_semigroups(
"LinearRNN": LinearRNNSemigroup(recurrent_size),
"LRU": LRUSemigroup(recurrent_size),
"S6": S6Semigroup(recurrent_size),
"S6D": S6DSemigroup(recurrent_size),
"NMax": NMaxSemigroup(recurrent_size),
"FWP": FWPSemigroup(recurrent_size),
"DeltaNet": DeltaNetSemigroup(recurrent_size),
@@ -244,81 +242,83 @@ def get_residual_memory_models(
models: str = "all",
*,
key: jax.random.PRNGKey,
layer_kwargs: Optional[Dict[str, Any]] = None,
model_kwargs: Optional[Dict] = None,
) -> Dict[str, Module]:
"""Returns a dictionary of models, correponding to all semigroups and set actions.

This returns a dictionary of models, each consisting of multiple recurrent cells
with residual and DenseNet connections between them.
"""
layer_kwargs = layer_kwargs or {}
model_kwargs = model_kwargs or {}
layers = {
# for debug
"MLP": lambda recurrent_size, key: MLP(
recurrent_size=recurrent_size, key=key
recurrent_size=recurrent_size, key=key, **layer_kwargs.get("MLP", {})
),
# semigroups
"NMax": lambda recurrent_size, key: NMax(
recurrent_size=recurrent_size, key=key
),
"FART": lambda recurrent_size, key: FART(
hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), key=key
hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), key=key, **layer_kwargs.get("FART", {})
),
"FWP": lambda recurrent_size, key: FWP(
hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), key=key
hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), key=key, **layer_kwargs.get("FWP", {})
),
"DeltaNet": lambda recurrent_size, key: DeltaNet(
hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), key=key
hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), key=key, **layer_kwargs.get("DeltaNet", {})
),
"DeltaProduct": lambda recurrent_size, key: DeltaProduct(
hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), rank=4, key=key
hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), rank=4, key=key, **layer_kwargs.get("DeltaProduct", {})
),
"GDN": lambda recurrent_size, key: GDN(
hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), key=key
hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), key=key, **layer_kwargs.get("GDN", {})
),
"FFM": lambda recurrent_size, key: FFM(
hidden_size=recurrent_size, context_size=recurrent_size//4, trace_size=4, key=key
),
"S6D": lambda recurrent_size, key: S6D(
hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key
hidden_size=recurrent_size, context_size=recurrent_size//4, trace_size=4, key=key, **layer_kwargs.get("FFM", {})
),
"S6": lambda recurrent_size, key: S6(
hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key
hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key, **layer_kwargs.get("S6", {})
),
"PSpherical": lambda recurrent_size, key: PSpherical(
recurrent_size=round(recurrent_size ** 0.5),
hidden_size=recurrent_size,
key=key
key=key,
**layer_kwargs.get("PSpherical", {})
),
"LRU": lambda recurrent_size, key: LRU(
hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key
hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key, **layer_kwargs.get("LRU", {})
),
"LinearRNN": lambda recurrent_size, key: LinearRecurrent(
recurrent_size=recurrent_size, key=key
recurrent_size=recurrent_size, key=key, **layer_kwargs.get("LinearRNN", {})
),
"Stack": lambda recurrent_size, key: Stack(
recurrent_size=recurrent_size, stack_size=4, key=key
recurrent_size=recurrent_size, stack_size=4, key=key, **layer_kwargs.get("Stack", {})
),
"Attention": lambda recurrent_size, key: Attention(
recurrent_size=recurrent_size, window_size=20, positional_embedding=None, key=key
recurrent_size=recurrent_size, window_size=20, positional_embedding=None, key=key, **layer_kwargs.get("Attention", {})
),
"Attention-RoPE": lambda recurrent_size, key: Attention(
recurrent_size=recurrent_size, window_size=20, positional_embedding="rope", key=key
recurrent_size=recurrent_size, window_size=20, positional_embedding="rope", key=key, **layer_kwargs.get("Attention-RoPE", {})
),
"Attention-ALiBi": lambda recurrent_size, key: Attention(
recurrent_size=recurrent_size, window_size=20, positional_embedding="alibi", key=key
recurrent_size=recurrent_size, window_size=20, positional_embedding="alibi", key=key, **layer_kwargs.get("Attention-ALiBi", {})
),
# set actions
"GRU": lambda recurrent_size, key: GRU(recurrent_size=recurrent_size, key=key),
"GRU": lambda recurrent_size, key: GRU(recurrent_size=recurrent_size, key=key, **layer_kwargs.get("GRU", {})),
"Elman": lambda recurrent_size, key: Elman(
hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key
hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key, **layer_kwargs.get("Elman", {})
),
"ElmanReLU": lambda recurrent_size, key: Elman(
hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key, activation=jax.nn.relu,
hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key, activation=jax.nn.relu, **layer_kwargs.get("ElmanReLU", {})
),
"Spherical": lambda recurrent_size, key: Spherical(
hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key
hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key, **layer_kwargs.get("Spherical", {})
),
"MGU": lambda recurrent_size, key: MGU(recurrent_size=recurrent_size, key=key),
"LSTM": lambda recurrent_size, key: LSTM(recurrent_size=recurrent_size, key=key),
"MGU": lambda recurrent_size, key: MGU(recurrent_size=recurrent_size, key=key, **layer_kwargs.get("MGU", {})),
"LSTM": lambda recurrent_size, key: LSTM(recurrent_size=recurrent_size, key=key, **layer_kwargs.get("LSTM", {})),
}
if models == "all":
return {
@@ -329,6 +329,7 @@ def get_residual_memory_models(
output_size=output,
num_layers=num_layers,
key=key,
**model_kwargs,
)
for name, fn in layers.items()
}
@@ -341,6 +342,7 @@ def get_residual_memory_models(
output_size=output,
num_layers=num_layers,
key=key,
**model_kwargs,
)
for name in models
}
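The new layer_kwargs and model_kwargs hooks follow a per-name dispatch pattern: each factory pulls only its own overrides out of a shared dict via `layer_kwargs.get(name, {})`, while `model_kwargs` is forwarded to the wrapper built around every layer stack. Below is a minimal, self-contained sketch of that pattern; the factory names, the `make_layer` helper, and its fields are placeholders for illustration, not memorax's API.

```python
# Illustrative sketch of the per-layer override pattern added in this diff.
from typing import Any, Callable, Dict, Optional

def make_layer(name: str, recurrent_size: int, activation: str = "identity", **extra: Any) -> Dict[str, Any]:
    # Placeholder constructor standing in for a real layer class.
    return {"name": name, "recurrent_size": recurrent_size, "activation": activation, **extra}

def build_layers(
    recurrent_size: int,
    layer_kwargs: Optional[Dict[str, Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    layer_kwargs = layer_kwargs or {}
    factories: Dict[str, Callable[..., Dict[str, Any]]] = {
        # Each factory forwards only its own overrides via layer_kwargs.get(name, {}).
        # An override that repeats a keyword the lambda already fixes
        # (recurrent_size here) raises a duplicate-keyword TypeError.
        "Elman": lambda **kw: make_layer("Elman", recurrent_size=recurrent_size, **kw),
        "GRU": lambda **kw: make_layer("GRU", recurrent_size=recurrent_size, **kw),
    }
    return {name: fn(**layer_kwargs.get(name, {})) for name, fn in factories.items()}

# Only the Elman entry receives the extra keyword; GRU keeps its defaults.
layers = build_layers(32, layer_kwargs={"Elman": {"activation": "relu"}})
print(layers["Elman"])  # {'name': 'Elman', 'recurrent_size': 32, 'activation': 'relu'}
```

The same caveat applies to the real factories in this diff: an override that repeats a keyword a lambda already sets (for example `window_size` for Attention or `rank` for DeltaProduct) will raise a "got multiple values for keyword argument" TypeError rather than silently replacing the default.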