From 76ea597b3e3e249e7b1e3ed22ae53475d96b0c02 Mon Sep 17 00:00:00 2001 From: Steven Morad Date: Sat, 22 Nov 2025 02:50:52 +0800 Subject: [PATCH] DeltaNet fixes --- memorax/equinox/semigroups/delta.py | 4 +++- memorax/equinox/semigroups/deltap.py | 3 ++- memorax/equinox/semigroups/gdn.py | 11 +++++++++-- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/memorax/equinox/semigroups/delta.py b/memorax/equinox/semigroups/delta.py index 4e99713..3317a3e 100644 --- a/memorax/equinox/semigroups/delta.py +++ b/memorax/equinox/semigroups/delta.py @@ -107,10 +107,11 @@ def forward_map( ) -> DeltaFWPRecurrentStateWithReset: emb, start = x k = phi(self.K(emb)) + k = k / (jnp.linalg.norm(k) + 1e-6) # normalize key v = self.V(emb) beta = psi(self.w(emb)) M = jnp.eye(self.recurrent_size) - beta * jnp.outer(k, k) - X = beta * jnp.outer(v, k) + X = beta * jnp.outer(v, k) return (M, X), start @jaxtyped(typechecker=typechecker) @@ -123,6 +124,7 @@ def backward_map( emb, start = x (M, X), reset_flag = h q = phi(self.Q(emb)) + q = q / (jnp.linalg.norm(q) + 1e-6) # normalize query return self.output(X @ q) @jaxtyped(typechecker=typechecker) diff --git a/memorax/equinox/semigroups/deltap.py b/memorax/equinox/semigroups/deltap.py index 36b0305..df2855c 100644 --- a/memorax/equinox/semigroups/deltap.py +++ b/memorax/equinox/semigroups/deltap.py @@ -112,7 +112,7 @@ def forward_map( ) -> DeltaProductRecurrentStateWithReset: emb, start = x k = phi(self.K(emb)).reshape(-1, self.rank) - k = k / (1e-8 + jnp.linalg.norm(k, axis=0, keepdims=True)) + k = k / (jnp.linalg.norm(k) + 1e-6) # normalize key v = self.V(emb).reshape(-1, self.rank) alpha = jax.nn.sigmoid(self.alpha(emb)) beta = psi(self.w(emb)).reshape(-1, self.rank) @@ -134,6 +134,7 @@ def backward_map( emb, start = x (M, X), reset_flag = h q = phi(self.Q(emb)) + q = q / (jnp.linalg.norm(q) + 1e-6) # normalize query return self.output(X @ q) @jaxtyped(typechecker=typechecker) diff --git a/memorax/equinox/semigroups/gdn.py b/memorax/equinox/semigroups/gdn.py index cc9d18d..263aa28 100644 --- a/memorax/equinox/semigroups/gdn.py +++ b/memorax/equinox/semigroups/gdn.py @@ -3,6 +3,7 @@ import jax import jax.numpy as jnp from beartype import beartype as typechecker +import equinox as eqx from equinox import nn from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped @@ -21,7 +22,7 @@ def phi(x, key=None): # https://arxiv.org/pdf/2102.11174 uses relu # https://arxiv.org/pdf/2406.06484 uses silu - return jax.nn.relu(x) + return jax.nn.silu(x) def psi(x, key=None): # https://arxiv.org/pdf/2102.11174 uses sigmoid @@ -100,7 +101,11 @@ def __init__(self, hidden_size, recurrent_size, key): self.Q = nn.Linear(hidden_size, recurrent_size, use_bias=False, key=keys[1]) self.V = nn.Linear(hidden_size, recurrent_size, use_bias=False, key=keys[2]) self.w = nn.Linear(hidden_size, 1, key=keys[3]) - self.alpha = nn.Linear(hidden_size, 1, key=keys[4]) + alpha = nn.Linear(hidden_size, 1, key=keys[4]) + # Initialize alpha bias to 4.0 so that sigmoid(alpha) is near 1.0 at init + self.alpha = eqx.tree_at( + lambda l: l.bias, alpha, jnp.full_like(alpha.bias, 4.0) + ) self.output = nn.Linear(recurrent_size, hidden_size, key=keys[5]) @jaxtyped(typechecker=typechecker) @@ -109,6 +114,7 @@ def forward_map( ) -> GDNRecurrentStateWithReset: emb, start = x k = phi(self.K(emb)) + k = k / (jnp.linalg.norm(k) + 1e-6) # normalize key v = self.V(emb) beta = psi(self.w(emb)) alpha = jax.nn.sigmoid(self.alpha(emb)) @@ -126,6 +132,7 @@ def backward_map( emb, start = x (M, X), reset_flag = h q = phi(self.Q(emb)) + q = q / (jnp.linalg.norm(q) + 1e-6) # normalize query return self.output(X @ q) @jaxtyped(typechecker=typechecker)