diff --git a/README.md b/README.md index 752c5af..34099f7 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,9 @@ We implement both linear and log-complexity recurrent models. | DeltaNet | $O(\log{n})$ | [[paper]](https://arxiv.org/pdf/2406.06484) | [[code]](memorax/equinox/semigroups/delta.py) | | Gated DeltaNet | $O(\log{n})$ | [[paper]](https://arxiv.org/pdf/2412.06464) | [[code]](memorax/equinox/semigroups/gdn.py) | | DeltaProduct | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2502.10297) | [[code]](memorax/equinox/semigroups/deltap.py) | -| Dot Product Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/1706.03762) | [[code]](memorax/equinox/semigroups/attn.py) | +| Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/1706.03762) | [[code]](memorax/equinox/semigroups/attn.py) | +| RoPE-Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2104.09864) | [[code]](memorax/equinox/semigroups/attn.py) | +| ALiBi-Attention | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2108.12409) | [[code]](memorax/equinox/semigroups/attn.py) | | Elman Network | $O(n)$ | [[paper]](https://www.sciencedirect.com/science/article/pii/036402139090002E) | [[code]](memorax/equinox/set_actions/elman.py) | | Gated Recurrent Unit | $O(n)$ | [[paper]](https://arxiv.org/abs/1412.3555) | [[code]](memorax/equinox/set_actions/gru.py) | | Minimal Gated Unit | $O(n)$ | [[paper]](https://arxiv.org/abs/1603.09420) | [[code]](memorax/equinox/set_actions/mgu.py) | diff --git a/memorax/equinox/semigroups/attn.py b/memorax/equinox/semigroups/attn.py index 921afa3..3971abb 100644 --- a/memorax/equinox/semigroups/attn.py +++ b/memorax/equinox/semigroups/attn.py @@ -4,15 +4,15 @@ import jax.numpy as jnp from beartype import beartype as typechecker from equinox import nn -from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped, Bool +from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped, Bool, Int from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable from memorax.equinox.gras import GRAS from memorax.mtypes import Input, StartFlag from memorax.equinox.scans import semigroup_scan -from memorax.utils import combine_and_right_align +from memorax.utils import apply_rope, apply_sinusoidal_pe, combine_and_right_align -AttentionRecurrentState = Tuple[Float[Array, "Window Recurrent"], Float[Array, "Window Recurrent"], Bool[Array, "Attention"]] +AttentionRecurrentState = Tuple[Float[Array, "Window Recurrent"], Float[Array, "Window Recurrent"], Bool[Array, "Window"], Int[Array, "Window"]] AttentionRecurrentStateWithReset = Tuple[AttentionRecurrentState, StartFlag] @@ -36,7 +36,8 @@ def initialize_carry( value = jnp.zeros((self.window_size, self.recurrent_size)) # Valid (non-pad) mask mask = jnp.zeros((self.window_size,), dtype=bool) - return (key, value, mask) + ts = jnp.zeros((self.window_size), dtype=jnp.int32) + return (key, value, mask, ts) @jaxtyped(typechecker=typechecker) def __call__( @@ -56,11 +57,12 @@ def __call__( # mask = jnp.concatenate([mleft, mright])[-window_size:] # So we use a tricky function instead - ckey, cvalue, cmask = carry - key, value, mask = input + ckey, cvalue, cmask, cts = carry + key, value, mask, ts = input out_key, out_mask = combine_and_right_align(ckey, cmask, key, mask) out_value, _ = combine_and_right_align(cvalue, cmask, value, mask) - return (out_key, out_value, out_mask) + out_ts = cts + ts + return (out_key, out_value, out_mask, out_ts) class Attention(GRAS): @@ -73,6 +75,7 @@ class Attention(GRAS): V: nn.Linear recurrent_size: int window_size: int + positional_embedding: Optional[str] scan: Callable[ [ Callable[ @@ -87,9 +90,17 @@ class Attention(GRAS): algebra: BinaryAlgebra - def __init__(self, recurrent_size, window_size, key): + def __init__(self, recurrent_size: int, window_size: int, positional_embedding: Optional[str], key): + """Standard dot-product attention with a sliding window. + Arguments: + recurrent_size: The size of the attention embeddings. + window_size: The size of the attention window (context length). + rope: Whether to use RoPE embeddings (False means no embeddings). + """ + assert positional_embedding in [None, "rope", "alibi"], "positional_embedding must be one of None, 'rope', or 'alibi'" self.recurrent_size = recurrent_size self.window_size = window_size + self.positional_embedding = positional_embedding self.algebra = Resettable(AttentionSemigroup(recurrent_size, window_size=window_size)) self.scan = semigroup_scan keys = jax.random.split(key, 5) @@ -103,7 +114,6 @@ def forward_map( ) -> AttentionRecurrentStateWithReset: emb, start = x # Add Attention dim for concat - #return emb.reshape(1, -1), start mask = jnp.concatenate([ jnp.zeros((self.window_size - 1), dtype=bool), jnp.ones((1,), dtype=bool) @@ -118,7 +128,8 @@ def forward_map( jnp.zeros((self.window_size - 1, *emb.shape), dtype=emb.dtype), v.reshape(1, -1) ]) - return (key, value, mask), start + ts = jnp.ones((self.window_size), dtype=jnp.int32) + return (key, value, mask, ts), start @jaxtyped(typechecker=typechecker) def backward_map( @@ -129,7 +140,8 @@ def backward_map( ) -> Float[Array, "{self.recurrent_size}"]: emb, start = x state, reset_carry = h - K, V, mask = state + K, V, mask, ts = state + q = self.Q(emb) # B = batch size # S = length of the key/value (source) @@ -139,11 +151,20 @@ def backward_map( # K = number of key/value heads # G = number of groups, which equals to N // K n, k, t, s, h = 1, 1, 1, self.window_size, self.recurrent_size + bias = None + if self.positional_embedding == "alibi": + m = 2 ** -8 + # T-1 to 0 + bias = m * (ts[0] + jnp.arange(-s + 1, 1)) + elif self.positional_embedding == "rope": + K, q = apply_rope(K, q) + + mask = mask.reshape(n, t, s) + bias = bias if bias is None else bias.reshape(n, t, s) K = K.reshape(s, k, h) - q = self.Q(emb).reshape(t, n, h) # Only for current timestep + q = q.reshape(t, n, h) # Only for current timestep V = V.reshape(s, k, h) - mask = mask.reshape(n, t, s) - z = jax.nn.dot_product_attention(q, K, V, mask=mask) + z = jax.nn.dot_product_attention(q, K, V, mask=mask, bias=bias) return z.reshape(-1) @jaxtyped(typechecker=typechecker) diff --git a/memorax/equinox/train_utils.py b/memorax/equinox/train_utils.py index 585cdec..ca137f8 100644 --- a/memorax/equinox/train_utils.py +++ b/memorax/equinox/train_utils.py @@ -298,7 +298,13 @@ def get_residual_memory_models( recurrent_size=recurrent_size, stack_size=4, key=key ), "Attention": lambda recurrent_size, key: Attention( - recurrent_size=recurrent_size, window_size=20, key=key + recurrent_size=recurrent_size, window_size=20, positional_embedding=None, key=key + ), + "Attention-RoPE": lambda recurrent_size, key: Attention( + recurrent_size=recurrent_size, window_size=20, positional_embedding="rope", key=key + ), + "Attention-ALiBi": lambda recurrent_size, key: Attention( + recurrent_size=recurrent_size, window_size=20, positional_embedding="alibi", key=key ), # set actions "GRU": lambda recurrent_size, key: GRU(recurrent_size=recurrent_size, key=key), diff --git a/memorax/utils.py b/memorax/utils.py index 7ce518c..954004f 100644 --- a/memorax/utils.py +++ b/memorax/utils.py @@ -5,7 +5,7 @@ from typing import Tuple import jax import jax.numpy as jnp -from jaxtyping import Array, Int, Shaped +from jaxtyping import Array, Int, Shaped, Float def debug_shape(x: jax.Array) -> str: @@ -15,27 +15,120 @@ def debug_shape(x: jax.Array) -> str: return eqx.tree_pprint(jax.tree.map(lambda x: {x.shape: x.dtype}, x)) - -def transformer_positional_encoding( - d_model: int, time_index: Int[Array, ""] -) -> jnp.ndarray: +def apply_rope(keys: Float[Array, "Time Feat"], query: Float[Array, "Feat"]) -> Tuple[Float[Array, "Time Feat"], Float[Array, "Feat"]]: """ - Generate a positional encoding vector for a given time index. - + Applies RoPE assuming contiguous time indices. + + Constraints: + - Keys correspond to time steps [1, 2, ..., T] + - Query corresponds to time step T + Args: - time_index (int): The time step index to encode. - d_model (int): The dimensionality of the encoding vector. - + keys: Array of shape (T, F) + query: Array of shape (F) + Returns: - jnp.ndarray: A positional encoding vector of shape (d_model,). + keys_rope: Embedded keys (T, F) + query_rope: Embedded query (F) """ - position = time_index - div_term = jnp.exp(jnp.arange(0, d_model, 2) * (-jnp.log(10000.0) / d_model)) - pos_encoding = jnp.zeros(d_model) - pos_encoding = pos_encoding.at[0::2].set(jnp.sin(position * div_term)) - pos_encoding = pos_encoding.at[1::2].set(jnp.cos(position * div_term)) - return pos_encoding - + T, F = keys.shape + assert F % 2 == 0, "Feature dimension must be even" + + # 1. Generate the Position Indices based on shape T + # Keys are positions 1 to T + key_indices = jnp.arange(1, T + 1, dtype=jnp.float32) + # Query is position T + query_index = jnp.array(T, dtype=jnp.float32) + + # 2. Calculate RoPE Frequencies (Theta) + # Standard formula: theta_i = 10000^(-2i/d) + theta_indices = jnp.arange(0, F, 2) + theta = 1.0 / (10000.0 ** (theta_indices / F)) # Shape: (F/2,) + + # 3. Create Complex Rotation Angles + # Keys: (T, F/2) -> broadcast positions against frequencies + key_angles = jnp.outer(key_indices, theta) + # Query: (F/2,) -> scalar T against frequencies + query_angle = query_index * theta + + # Calculate rotation vectors: e^(i * angle) + key_rotators = jnp.exp(1j * key_angles) + query_rotator = jnp.exp(1j * query_angle) + + # 4. Apply Rotation using Complex Numbers + # Reshape (T, F) -> (T, F/2, 2) and convert to complex + keys_complex = keys.reshape(T, -1, 2) + keys_complex = keys_complex[..., 0] + 1j * keys_complex[..., 1] + + # Reshape (F) -> (F/2, 2) and convert to complex + query_complex = query.reshape(-1, 2) + query_complex = query_complex[..., 0] + 1j * query_complex[..., 1] + + # Multiply (rotate) + keys_out_complex = keys_complex * key_rotators + query_out_complex = query_complex * query_rotator + + # 5. Convert back to Real + keys_rope = jnp.stack([keys_out_complex.real, keys_out_complex.imag], axis=-1).reshape(T, F) + query_rope = jnp.stack([query_out_complex.real, query_out_complex.imag], axis=-1).reshape(F) + + return keys_rope, query_rope + +def apply_sinusoidal_pe(keys: Float[Array, "Time Feat"], query: Float[Array, "Feat"], offset: Int[Array, ""] = jnp.array(0)): + """ + Applies Standard Sinusoidal Positional Encoding with a temporal offset. + + Args: + keys: Array of shape (T, F). + query: Array of shape (F). + offset: (int or scalar) The starting time offset. + If offset=10, keys map to positions 11...10+T. + + Returns: + keys_pe: keys + PE(pos) + query_pe: query + PE(pos) + """ + T, F = keys.shape + # Don't allow python ints which force recompile + assert isinstance(offset, jax.Array), "Offset must be a JAX array scalar." + assert F % 2 == 0, "Feature dimension F must be even." + + # 1. Define Positions with Offset + # Cast to float32 immediately for instruction efficiency in sin/cos later + offset_arr = jnp.array(offset, dtype=jnp.float32) + + # Keys: [1+offset, 2+offset, ..., T+offset] + key_positions = jnp.arange(1, T + 1, dtype=jnp.float32) + offset_arr + key_positions = key_positions[:, None] # Shape (T, 1) for broadcasting + + # Query: T + offset + # corresponds to the last time step in this batch + query_position = (jnp.array(T, dtype=jnp.float32) + offset_arr) + + # 2. Calculate Frequency Divisor + # Note: Standard simplified implementation is just F. + # The exact Vaswani paper uses exp(arange(0, d, 2) * -(log(10000.0) / d)) + + dim_indices = jnp.arange(0, F, 2, dtype=jnp.float32) + div_term = jnp.exp(dim_indices * -(jnp.log(10000.0) / F)) # Shape (F/2,) + + # 3. Generate Embeddings for Keys + # Broadcast (T, 1) * (F/2,) -> (T, F/2) + k_args = key_positions * div_term + + # Interleave Sin/Cos for keys + # Shape: (T, F/2, 2) -> Flatten to (T, F) + pe_keys = jnp.stack([jnp.sin(k_args), jnp.cos(k_args)], axis=-1).reshape(T, F) + + # 4. Generate Embeddings for Query + # Broadcast Scalar * (F/2,) -> (F/2,) + q_args = query_position * div_term + + # Interleave Sin/Cos for query + pe_query = jnp.stack([jnp.sin(q_args), jnp.cos(q_args)], axis=-1).reshape(F) + + # 5. Add to Inputs + return keys + pe_keys, query + pe_query def combine_and_right_align( left_array: Shaped[Array, "Time Feat"], diff --git a/tests/test_initial_input_equinox.py b/tests/test_initial_input_equinox.py index d528097..61824a4 100644 --- a/tests/test_initial_input_equinox.py +++ b/tests/test_initial_input_equinox.py @@ -12,27 +12,29 @@ def get_desired_accuracies(): return { "MLP": 0, "Stack": 0, - "Attention": 0.999, - "DLSE": 0.999, - "FFM": 0.999, - "FART": 0.999, - "FWP": 0.999, - "DeltaNet": 0.999, - "DeltaProduct": 0.999, - "GDN": 0.999, - "LRU": 0.999, - "S6": 0.999, - "LinearRNN": 0.999, - "PSpherical": 0.999, - "GRU": 0.999, + "Attention": 0.99, + "Attention-RoPE": 0.99, + "Attention-ALiBi": 0.99, + "DLSE": 0.99, + "FFM": 0.99, + "FART": 0.99, + "FWP": 0.99, + "DeltaNet": 0.99, + "DeltaProduct": 0.99, + "GDN": 0.99, + "LRU": 0.99, + "S6": 0.99, + "LinearRNN": 0.99, + "PSpherical": 0.99, + "GRU": 0.99, "Elman": 0.55, "ElmanReLU": 0.55, - "Spherical": 0.999, - "NMax": 0.999, - "MGU": 0.999, - "LSTM": 0.999, - "S6D": 0.999, - "S6": 0.999, + "Spherical": 0.99, + "NMax": 0.99, + "MGU": 0.99, + "LSTM": 0.99, + "S6D": 0.99, + "S6": 0.99, } @@ -43,7 +45,7 @@ def ce_loss(y_hat, y): 4, 8, 4 - 1, key=jax.random.key(0), ).items()) def test_initial_input( - model_name, model, epochs=4000, num_seqs=5, seq_len=20, input_dims=4 + model_name, model, epochs=2000, num_seqs=5, seq_len=20, input_dims=4 ): timesteps = num_seqs * seq_len seq_idx = jnp.array([seq_len * i for i in range(num_seqs)])