Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion memorax/equinox/semigroups/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,11 @@ def forward_map(
) -> DeltaFWPRecurrentStateWithReset:
emb, start = x
k = phi(self.K(emb))
k = k / (jnp.linalg.norm(k) + 1e-6) # normalize key
v = self.V(emb)
beta = psi(self.w(emb))
M = jnp.eye(self.recurrent_size) - beta * jnp.outer(k, k)
X = beta * jnp.outer(v, k)
X = beta * jnp.outer(v, k)
return (M, X), start

@jaxtyped(typechecker=typechecker)
Expand All @@ -123,6 +124,7 @@ def backward_map(
emb, start = x
(M, X), reset_flag = h
q = phi(self.Q(emb))
q = q / (jnp.linalg.norm(q) + 1e-6) # normalize query
return self.output(X @ q)

@jaxtyped(typechecker=typechecker)
Expand Down
3 changes: 2 additions & 1 deletion memorax/equinox/semigroups/deltap.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def forward_map(
) -> DeltaProductRecurrentStateWithReset:
emb, start = x
k = phi(self.K(emb)).reshape(-1, self.rank)
k = k / (1e-8 + jnp.linalg.norm(k, axis=0, keepdims=True))
k = k / (jnp.linalg.norm(k) + 1e-6) # normalize key
v = self.V(emb).reshape(-1, self.rank)
alpha = jax.nn.sigmoid(self.alpha(emb))
beta = psi(self.w(emb)).reshape(-1, self.rank)
Expand All @@ -134,6 +134,7 @@ def backward_map(
emb, start = x
(M, X), reset_flag = h
q = phi(self.Q(emb))
q = q / (jnp.linalg.norm(q) + 1e-6) # normalize query
return self.output(X @ q)

@jaxtyped(typechecker=typechecker)
Expand Down
11 changes: 9 additions & 2 deletions memorax/equinox/semigroups/gdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import jax
import jax.numpy as jnp
from beartype import beartype as typechecker
import equinox as eqx
from equinox import nn
from jaxtyping import Array, Float, PRNGKeyArray, Shaped, jaxtyped

Expand All @@ -21,7 +22,7 @@
def phi(x, key=None):
# https://arxiv.org/pdf/2102.11174 uses relu
# https://arxiv.org/pdf/2406.06484 uses silu
return jax.nn.relu(x)
return jax.nn.silu(x)

def psi(x, key=None):
# https://arxiv.org/pdf/2102.11174 uses sigmoid
Expand Down Expand Up @@ -100,7 +101,11 @@ def __init__(self, hidden_size, recurrent_size, key):
self.Q = nn.Linear(hidden_size, recurrent_size, use_bias=False, key=keys[1])
self.V = nn.Linear(hidden_size, recurrent_size, use_bias=False, key=keys[2])
self.w = nn.Linear(hidden_size, 1, key=keys[3])
self.alpha = nn.Linear(hidden_size, 1, key=keys[4])
alpha = nn.Linear(hidden_size, 1, key=keys[4])
# Initialize alpha bias to 4.0 so that sigmoid(alpha) is near 1.0 at init
self.alpha = eqx.tree_at(
lambda l: l.bias, alpha, jnp.full_like(alpha.bias, 4.0)
)
self.output = nn.Linear(recurrent_size, hidden_size, key=keys[5])

@jaxtyped(typechecker=typechecker)
Expand All @@ -109,6 +114,7 @@ def forward_map(
) -> GDNRecurrentStateWithReset:
emb, start = x
k = phi(self.K(emb))
k = k / (jnp.linalg.norm(k) + 1e-6) # normalize key
v = self.V(emb)
beta = psi(self.w(emb))
alpha = jax.nn.sigmoid(self.alpha(emb))
Expand All @@ -126,6 +132,7 @@ def backward_map(
emb, start = x
(M, X), reset_flag = h
q = phi(self.Q(emb))
q = q / (jnp.linalg.norm(q) + 1e-6) # normalize query
return self.output(X @ q)

@jaxtyped(typechecker=typechecker)
Expand Down