4 changes: 3 additions & 1 deletion README.md
@@ -27,7 +27,9 @@ We implement both linear and log-complexity recurrent models.
| DeltaNet | $O(\log{n})$ | [[paper]](https://arxiv.org/pdf/2406.06484) | [[code]](memorax/equinox/semigroups/delta.py) |
| Gated DeltaNet | $O(\log{n})$ | [[paper]](https://arxiv.org/pdf/2412.06464) | [[code]](memorax/equinox/semigroups/gdn.py) |
| DeltaProduct | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2502.10297) | [[code]](memorax/equinox/semigroups/deltap.py) |
| Dot Product Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/1706.03762) | [[code]](memorax/equinox/semigroups/attn.py) |
| Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/1706.03762) | [[code]](memorax/equinox/semigroups/attn.py) |
| RoPE-Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2104.09864) | [[code]](memorax/equinox/semigroups/attn.py) |
| ALiBi-Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2108.12409) | [[code]](memorax/equinox/semigroups/attn.py) |
| Elman Network | $O(n)$ | [[paper]](https://www.sciencedirect.com/science/article/pii/036402139090002E) | [[code]](memorax/equinox/set_actions/elman.py) |
| Gated Recurrent Unit | $O(n)$ | [[paper]](https://arxiv.org/abs/1412.3555) | [[code]](memorax/equinox/set_actions/gru.py) |
| Minimal Gated Unit | $O(n)$ | [[paper]](https://arxiv.org/abs/1603.09420) | [[code]](memorax/equinox/set_actions/mgu.py) |
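The three attention rows above are all served by one `Attention` class, selected via its new `positional_embedding` argument. Below is a minimal, illustrative construction sketch (the import path is inferred from the table's [[code]] link; the registration in `train_utils.py` further down uses the same keyword arguments):

```python
import jax

from memorax.equinox.semigroups.attn import Attention  # path assumed from the table's [[code]] link

key = jax.random.key(0)
recurrent_size, window_size = 8, 20

# positional_embedding selects the scheme: None (plain), "rope", or "alibi"
attn       = Attention(recurrent_size, window_size, positional_embedding=None, key=key)
attn_rope  = Attention(recurrent_size, window_size, positional_embedding="rope", key=key)
attn_alibi = Attention(recurrent_size, window_size, positional_embedding="alibi", key=key)
```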
49 changes: 35 additions & 14 deletions memorax/equinox/semigroups/attn.py
@@ -4,15 +4,15 @@
import jax.numpy as jnp
from beartype import beartype as typechecker
from equinox import nn
from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped, Bool
from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped, Bool, Int

from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable
from memorax.equinox.gras import GRAS
from memorax.mtypes import Input, StartFlag
from memorax.equinox.scans import semigroup_scan
from memorax.utils import combine_and_right_align
from memorax.utils import apply_rope, apply_sinusoidal_pe, combine_and_right_align

AttentionRecurrentState = Tuple[Float[Array, "Window Recurrent"], Float[Array, "Window Recurrent"], Bool[Array, "Attention"]]
AttentionRecurrentState = Tuple[Float[Array, "Window Recurrent"], Float[Array, "Window Recurrent"], Bool[Array, "Window"], Int[Array, "Window"]]
AttentionRecurrentStateWithReset = Tuple[AttentionRecurrentState, StartFlag]


@@ -36,7 +36,8 @@ def initialize_carry(
value = jnp.zeros((self.window_size, self.recurrent_size))
# Valid (non-pad) mask
mask = jnp.zeros((self.window_size,), dtype=bool)
return (key, value, mask)
ts = jnp.zeros((self.window_size), dtype=jnp.int32)
return (key, value, mask, ts)

@jaxtyped(typechecker=typechecker)
def __call__(
@@ -56,11 +57,12 @@ def __call__(
# mask = jnp.concatenate([mleft, mright])[-window_size:]

# So we use a tricky function instead
ckey, cvalue, cmask = carry
key, value, mask = input
ckey, cvalue, cmask, cts = carry
key, value, mask, ts = input
out_key, out_mask = combine_and_right_align(ckey, cmask, key, mask)
out_value, _ = combine_and_right_align(cvalue, cmask, value, mask)
return (out_key, out_value, out_mask)
out_ts = cts + ts
return (out_key, out_value, out_mask, out_ts)


class Attention(GRAS):
@@ -73,6 +75,7 @@ class Attention(GRAS):
V: nn.Linear
recurrent_size: int
window_size: int
positional_embedding: Optional[str]
scan: Callable[
[
Callable[
@@ -87,9 +90,17 @@
algebra: BinaryAlgebra


def __init__(self, recurrent_size, window_size, key):
def __init__(self, recurrent_size: int, window_size: int, positional_embedding: Optional[str], key):
"""Standard dot-product attention with a sliding window.
Arguments:
recurrent_size: The size of the attention embeddings.
window_size: The size of the attention window (context length).
positional_embedding: One of None, "rope", or "alibi" (None disables positional embeddings).
"""
assert positional_embedding in [None, "rope", "alibi"], "positional_embedding must be one of None, 'rope', or 'alibi'"
self.recurrent_size = recurrent_size
self.window_size = window_size
self.positional_embedding = positional_embedding
self.algebra = Resettable(AttentionSemigroup(recurrent_size, window_size=window_size))
self.scan = semigroup_scan
keys = jax.random.split(key, 5)
@@ -103,7 +114,6 @@ def forward_map(
) -> AttentionRecurrentStateWithReset:
emb, start = x
# Add Attention dim for concat
#return emb.reshape(1, -1), start
mask = jnp.concatenate([
jnp.zeros((self.window_size - 1), dtype=bool),
jnp.ones((1,), dtype=bool)
@@ -118,7 +128,8 @@
jnp.zeros((self.window_size - 1, *emb.shape), dtype=emb.dtype),
v.reshape(1, -1)
])
return (key, value, mask), start
ts = jnp.ones((self.window_size), dtype=jnp.int32)
return (key, value, mask, ts), start

@jaxtyped(typechecker=typechecker)
def backward_map(
@@ -129,7 +140,8 @@
) -> Float[Array, "{self.recurrent_size}"]:
emb, start = x
state, reset_carry = h
K, V, mask = state
K, V, mask, ts = state
q = self.Q(emb)

# B = batch size
# S = length of the key/value (source)
@@ -139,11 +151,20 @@
# K = number of key/value heads
# G = number of groups, which equals to N // K
n, k, t, s, h = 1, 1, 1, self.window_size, self.recurrent_size
bias = None
if self.positional_embedding == "alibi":
m = 2 ** -8
# Relative offsets -(S-1) .. 0 for the window slots, shifted by the running timestep count ts[0]
bias = m * (ts[0] + jnp.arange(-s + 1, 1))
elif self.positional_embedding == "rope":
K, q = apply_rope(K, q)

mask = mask.reshape(n, t, s)
bias = bias if bias is None else bias.reshape(n, t, s)
K = K.reshape(s, k, h)
q = self.Q(emb).reshape(t, n, h) # Only for current timestep
q = q.reshape(t, n, h) # Only for current timestep
V = V.reshape(s, k, h)
mask = mask.reshape(n, t, s)
z = jax.nn.dot_product_attention(q, K, V, mask=mask)
z = jax.nn.dot_product_attention(q, K, V, mask=mask, bias=bias)
return z.reshape(-1)
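To make the shapes and the ALiBi path above concrete, here is a small, self-contained sketch (illustration only, not part of the diff) of calling `jax.nn.dot_product_attention` with a window mask and an ALiBi-style bias. It assumes a recent JAX release that provides this function, and it omits the `ts[0]` shift used above, since a per-query constant offset does not change the softmax:

```python
import jax
import jax.numpy as jnp

S, H = 20, 8                           # window length and feature size (window_size, recurrent_size)
q = jnp.ones((1, 1, H))                # (T=1, N=1, H): query for the current timestep only
K = jnp.ones((S, 1, H))                # (S, K=1, H): cached keys for the window
V = jnp.ones((S, 1, H))                # (S, K=1, H): cached values for the window

mask = (jnp.arange(S) >= S - 3).reshape(1, 1, S)     # only the 3 newest slots are valid; (N, T, S)

m = 2 ** -8                                          # single-head ALiBi slope, as in the diff
bias = (m * jnp.arange(-S + 1, 1)).reshape(1, 1, S)  # 0 for the newest slot, more negative for older ones
bias = bias.astype(q.dtype)

z = jax.nn.dot_product_attention(q, K, V, mask=mask, bias=bias)
print(z.shape)                         # (1, 1, H); backward_map flattens this to (H,)
```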

@jaxtyped(typechecker=typechecker)
8 changes: 7 additions & 1 deletion memorax/equinox/train_utils.py
@@ -298,7 +298,13 @@ def get_residual_memory_models(
recurrent_size=recurrent_size, stack_size=4, key=key
),
"Attention": lambda recurrent_size, key: Attention(
recurrent_size=recurrent_size, window_size=20, key=key
recurrent_size=recurrent_size, window_size=20, positional_embedding=None, key=key
),
"Attention-RoPE": lambda recurrent_size, key: Attention(
recurrent_size=recurrent_size, window_size=20, positional_embedding="rope", key=key
),
"Attention-ALiBi": lambda recurrent_size, key: Attention(
recurrent_size=recurrent_size, window_size=20, positional_embedding="alibi", key=key
),
# set actions
"GRU": lambda recurrent_size, key: GRU(recurrent_size=recurrent_size, key=key),
129 changes: 111 additions & 18 deletions memorax/utils.py
@@ -5,7 +5,7 @@
from typing import Tuple
import jax
import jax.numpy as jnp
from jaxtyping import Array, Int, Shaped
from jaxtyping import Array, Int, Shaped, Float


def debug_shape(x: jax.Array) -> str:
@@ -15,27 +15,120 @@ def debug_shape(x: jax.Array) -> str:

return eqx.tree_pprint(jax.tree.map(lambda x: {x.shape: x.dtype}, x))


def transformer_positional_encoding(
d_model: int, time_index: Int[Array, ""]
) -> jnp.ndarray:
def apply_rope(keys: Float[Array, "Time Feat"], query: Float[Array, "Feat"]) -> Tuple[Float[Array, "Time Feat"], Float[Array, "Feat"]]:
"""
Generate a positional encoding vector for a given time index.

Applies RoPE assuming contiguous time indices.

Constraints:
- Keys correspond to time steps [1, 2, ..., T]
- Query corresponds to time step T

Args:
time_index (int): The time step index to encode.
d_model (int): The dimensionality of the encoding vector.

keys: Array of shape (T, F)
query: Array of shape (F)
Returns:
jnp.ndarray: A positional encoding vector of shape (d_model,).
keys_rope: Embedded keys (T, F)
query_rope: Embedded query (F)
"""
position = time_index
div_term = jnp.exp(jnp.arange(0, d_model, 2) * (-jnp.log(10000.0) / d_model))
pos_encoding = jnp.zeros(d_model)
pos_encoding = pos_encoding.at[0::2].set(jnp.sin(position * div_term))
pos_encoding = pos_encoding.at[1::2].set(jnp.cos(position * div_term))
return pos_encoding

T, F = keys.shape
assert F % 2 == 0, "Feature dimension must be even"

# 1. Generate the Position Indices based on shape T
# Keys are positions 1 to T
key_indices = jnp.arange(1, T + 1, dtype=jnp.float32)
# Query is position T
query_index = jnp.array(T, dtype=jnp.float32)

# 2. Calculate RoPE Frequencies (Theta)
# Standard formula: theta_i = 10000^(-2i/d)
theta_indices = jnp.arange(0, F, 2)
theta = 1.0 / (10000.0 ** (theta_indices / F)) # Shape: (F/2,)

# 3. Create Complex Rotation Angles
# Keys: (T, F/2) -> broadcast positions against frequencies
key_angles = jnp.outer(key_indices, theta)
# Query: (F/2,) -> scalar T against frequencies
query_angle = query_index * theta

# Calculate rotation vectors: e^(i * angle)
key_rotators = jnp.exp(1j * key_angles)
query_rotator = jnp.exp(1j * query_angle)

# 4. Apply Rotation using Complex Numbers
# Reshape (T, F) -> (T, F/2, 2) and convert to complex
keys_complex = keys.reshape(T, -1, 2)
keys_complex = keys_complex[..., 0] + 1j * keys_complex[..., 1]

# Reshape (F) -> (F/2, 2) and convert to complex
query_complex = query.reshape(-1, 2)
query_complex = query_complex[..., 0] + 1j * query_complex[..., 1]

# Multiply (rotate)
keys_out_complex = keys_complex * key_rotators
query_out_complex = query_complex * query_rotator

# 5. Convert back to Real
keys_rope = jnp.stack([keys_out_complex.real, keys_out_complex.imag], axis=-1).reshape(T, F)
query_rope = jnp.stack([query_out_complex.real, query_out_complex.imag], axis=-1).reshape(F)

return keys_rope, query_rope
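A quick sanity check for `apply_rope` (illustrative, not part of the diff): the last key and the query are both assigned position T, so they receive the same rotation, and because rotations preserve inner products their attention score is unchanged, while earlier keys are rotated relative to the query:

```python
import jax
import jax.numpy as jnp

from memorax.utils import apply_rope

T, F = 5, 8
keys = jax.random.normal(jax.random.key(0), (T, F))
query = jax.random.normal(jax.random.key(1), (F,))

keys_rope, query_rope = apply_rope(keys, query)

# Position T vs. position T: same rotation, so the inner product is preserved.
assert jnp.allclose(jnp.dot(query_rope, keys_rope[-1]), jnp.dot(query, keys[-1]), atol=1e-4)

# Position 1 vs. position T: rotated by the relative offset, so the score generally changes.
print(jnp.dot(query_rope, keys_rope[0]) - jnp.dot(query, keys[0]))
```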

def apply_sinusoidal_pe(keys: Float[Array, "Time Feat"], query: Float[Array, "Feat"], offset: Int[Array, ""] = jnp.array(0)):
"""
Applies Standard Sinusoidal Positional Encoding with a temporal offset.

Args:
keys: Array of shape (T, F).
query: Array of shape (F).
offset: (int or scalar) The starting time offset.
If offset=10, keys map to positions 11...10+T.

Returns:
keys_pe: keys + PE(pos)
query_pe: query + PE(pos)
"""
T, F = keys.shape
# Require a JAX scalar; a Python int would be baked in as a constant and force a recompile per value
assert isinstance(offset, jax.Array), "Offset must be a JAX array scalar."
assert F % 2 == 0, "Feature dimension F must be even."

# 1. Define Positions with Offset
# Cast to float32 immediately for instruction efficiency in sin/cos later
offset_arr = jnp.array(offset, dtype=jnp.float32)

# Keys: [1+offset, 2+offset, ..., T+offset]
key_positions = jnp.arange(1, T + 1, dtype=jnp.float32) + offset_arr
key_positions = key_positions[:, None] # Shape (T, 1) for broadcasting

# Query: T + offset
# corresponds to the last time step in this batch
query_position = (jnp.array(T, dtype=jnp.float32) + offset_arr)

# 2. Calculate Frequency Divisor
# The code below follows the Vaswani et al. formulation:
# div_term = exp(arange(0, d, 2) * -(log(10000.0) / d))

dim_indices = jnp.arange(0, F, 2, dtype=jnp.float32)
div_term = jnp.exp(dim_indices * -(jnp.log(10000.0) / F)) # Shape (F/2,)

# 3. Generate Embeddings for Keys
# Broadcast (T, 1) * (F/2,) -> (T, F/2)
k_args = key_positions * div_term

# Interleave Sin/Cos for keys
# Shape: (T, F/2, 2) -> Flatten to (T, F)
pe_keys = jnp.stack([jnp.sin(k_args), jnp.cos(k_args)], axis=-1).reshape(T, F)

# 4. Generate Embeddings for Query
# Broadcast Scalar * (F/2,) -> (F/2,)
q_args = query_position * div_term

# Interleave Sin/Cos for query
pe_query = jnp.stack([jnp.sin(q_args), jnp.cos(q_args)], axis=-1).reshape(F)

# 5. Add to Inputs
return keys + pe_keys, query + pe_query
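And a short usage sketch for `apply_sinusoidal_pe` (illustrative only; note the function requires the offset to be a JAX scalar, not a Python int):

```python
import jax
import jax.numpy as jnp

from memorax.utils import apply_sinusoidal_pe

T, F = 4, 8
keys = jax.random.normal(jax.random.key(0), (T, F))
query = jax.random.normal(jax.random.key(1), (F,))

# With offset=10, keys are encoded at positions 11..14 and the query at position 14.
keys_pe, query_pe = apply_sinusoidal_pe(keys, query, offset=jnp.array(10))
print(keys_pe.shape, query_pe.shape)   # (4, 8) (8,)
```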

def combine_and_right_align(
left_array: Shaped[Array, "Time Feat"],
42 changes: 22 additions & 20 deletions tests/test_initial_input_equinox.py
@@ -12,27 +12,29 @@ def get_desired_accuracies():
return {
"MLP": 0,
"Stack": 0,
"Attention": 0.999,
"DLSE": 0.999,
"FFM": 0.999,
"FART": 0.999,
"FWP": 0.999,
"DeltaNet": 0.999,
"DeltaProduct": 0.999,
"GDN": 0.999,
"LRU": 0.999,
"S6": 0.999,
"LinearRNN": 0.999,
"PSpherical": 0.999,
"GRU": 0.999,
"Attention": 0.99,
"Attention-RoPE": 0.99,
"Attention-ALiBi": 0.99,
"DLSE": 0.99,
"FFM": 0.99,
"FART": 0.99,
"FWP": 0.99,
"DeltaNet": 0.99,
"DeltaProduct": 0.99,
"GDN": 0.99,
"LRU": 0.99,
"S6": 0.99,
"LinearRNN": 0.99,
"PSpherical": 0.99,
"GRU": 0.99,
"Elman": 0.55,
"ElmanReLU": 0.55,
"Spherical": 0.999,
"NMax": 0.999,
"MGU": 0.999,
"LSTM": 0.999,
"S6D": 0.999,
"S6": 0.999,
"Spherical": 0.99,
"NMax": 0.99,
"MGU": 0.99,
"LSTM": 0.99,
"S6D": 0.99,
"S6": 0.99,
}


@@ -43,7 +45,7 @@ def ce_loss(y_hat, y):
4, 8, 4 - 1, key=jax.random.key(0),
).items())
def test_initial_input(
model_name, model, epochs=4000, num_seqs=5, seq_len=20, input_dims=4
model_name, model, epochs=2000, num_seqs=5, seq_len=20, input_dims=4
):
timesteps = num_seqs * seq_len
seq_idx = jnp.array([seq_len * i for i in range(num_seqs)])