diff --git a/README.md b/README.md
index 34099f7..9a3dd55 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,6 @@ We implement both linear and log-complexity recurrent models.
 |------|--------------------------|-------|------|
 | Linear Recurrent Unit | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2303.06349) | [[code]](memorax/equinox/semigroups/lru.py) |
 | Selective State Space Model (S6) | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2312.00752) | [[code]](memorax/equinox/semigroups/s6.py) |
-| Diagonal Selective State Space Model (S6D) | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2312.00752) | [[code]](memorax/equinox/semigroups/s6d.py) |
 | Linear Recurrent Neural Network | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/1709.04057) | [[code]](memorax/equinox/semigroups/lrnn.py) |
 | Fast Autoregressive Transformer | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2006.16236) | [[code]](memorax/equinox/semigroups/fart.py) |
 | Fast and Forgetful Memory | $O(\log{n})$ | [[paper]](https://arxiv.org/abs/2310.04128) | [[code]](memorax/equinox/semigroups/ffm.py) |
diff --git a/memorax/equinox/semigroups/__init__.py b/memorax/equinox/semigroups/__init__.py
index ceee906..b33c23a 100644
--- a/memorax/equinox/semigroups/__init__.py
+++ b/memorax/equinox/semigroups/__init__.py
@@ -14,6 +14,5 @@
 + `memorax.equinox.semigroups.lrnn` provides a basic linear recurrence.
 + `memorax.equinox.semigroups.mlp` provides an MLP (no memory) for completeness.
 + `memorax.equinox.semigroups.s6` provides the Selective State Space Model (Mamba) layer.
-+ `memorax.equinox.semigroups.s6d` provides a diagonal variant of the Selective State Space Model (Mamba) layer.
 + `memorax.equinox.semigroups.spherical` provides Rotational RNN layer (spherical projection).
 """
\ No newline at end of file
diff --git a/memorax/equinox/semigroups/s6.py b/memorax/equinox/semigroups/s6.py
index 56b7406..a8819c1 100644
--- a/memorax/equinox/semigroups/s6.py
+++ b/memorax/equinox/semigroups/s6.py
@@ -16,18 +16,10 @@
 S6RecurrentStateWithReset = Tuple[S6RecurrentState, StartFlag]
 
 
-# Inits
-@jaxtyped(typechecker=typechecker)
-def glorot_init(
-    key: PRNGKeyArray, shape: Tuple[int, ...], normalization: Scalar = jnp.array(1.0)
-):
-    return jax.random.normal(key=key, shape=shape) / normalization
-
-
 class S6Semigroup(Semigroup):
-    """The full-rank S6 semigroup (recurrent update) from https://arxiv.org/abs/2312.00752.
+    """The S6 semigroup (recurrent update) from https://arxiv.org/abs/2312.00752.
 
-    This is a S5/LRU recurrent update with a learnable timestep parameter. """
+    This is an S5/LRU recurrent update with a learnable timestep parameter. """
 
     recurrent_size: int
 
@@ -58,8 +50,7 @@ def __call__(
 
 class S6(GRAS):
     """
-    The full-rank S6 SSM, an SSM with a trainable dt.
-    The recurrent matrix A is diagonal, but the B, C matrices are full-rank.
+    The S6 SSM with diagonal recurrence. We base this on the LRU, and add a trainable dt.
 
     You might want to use this as a building block for a more complex model.
     """
@@ -93,8 +84,8 @@ def __init__(self, recurrent_size, hidden_size, key):
         self.scan = semigroup_scan
 
         self.A_log = jax.random.normal(keys[0], (self.recurrent_size,))
-        self.B = nn.Linear(self.hidden_size, self.recurrent_size * self.recurrent_size, key=keys[1], use_bias=False)
-        self.C = nn.Linear(self.recurrent_size, self.hidden_size, key=keys[2], use_bias=False)
+        self.B = nn.Linear(self.hidden_size, self.recurrent_size, key=keys[1])
+        self.C = nn.Linear(self.recurrent_size, self.hidden_size, key=keys[2])
         self.dt = nn.Sequential([
             nn.Linear(self.hidden_size, self.recurrent_size, key=keys[3]),
             nn.Lambda(jax.nn.softplus)
@@ -106,12 +97,12 @@ def forward_map(self, x: Input, key: Optional[Shaped[PRNGKeyArray, ""]] = None):
         dt = self.dt(emb)
         A = -jnp.exp(self.A_log)
         A_bar = jnp.exp(dt * A)
-        B = self.B(emb).reshape(self.recurrent_size, self.recurrent_size)
-        # NOTE: A is diagonal so we can compute B_bar more simply than the mamba paper
+        B = self.B(emb)
+        # NOTE: A and B are diagonal so we can compute B_bar more simply than the mamba paper
         # Thankfully, inv(A) is just 1 / A if A is diagonal
         # Furthermore the dt's cancel: 1 / (dt A) with dt B
-        B_bar = jnp.diag(1 / A * (A_bar - 1.0)) @ B
-        Bu = B_bar @ emb
+        B_bar = 1 / A * (A_bar - 1.0) * B
+        Bu = B_bar * emb
         return (A_bar, Bu), start
 
     @jaxtyped(typechecker=typechecker)
diff --git a/memorax/equinox/semigroups/s6d.py b/memorax/equinox/semigroups/s6d.py
deleted file mode 100644
index d5a5641..0000000
--- a/memorax/equinox/semigroups/s6d.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# https://github.com/NicolasZucchet/minimal-S6/blob/main/S6/model.py
-from beartype.typing import Callable, Optional, Tuple
-
-import jax
-import jax.numpy as jnp
-from equinox import nn
-from beartype import beartype as typechecker
-from jaxtyping import Array, Complex, Float, PRNGKeyArray, Scalar, Shaped, jaxtyped
-
-from memorax.equinox.groups import BinaryAlgebra, Semigroup, Resettable
-from memorax.equinox.gras import GRAS
-from memorax.mtypes import Input, StartFlag
-from memorax.equinox.scans import semigroup_scan
-
-S6RecurrentState = Tuple[Float[Array, "Recurrent"], Float[Array, "Recurrent"]]
-S6RecurrentStateWithReset = Tuple[S6RecurrentState, StartFlag]
-
-
-# Inits
-@jaxtyped(typechecker=typechecker)
-def glorot_init(
-    key: PRNGKeyArray, shape: Tuple[int, ...], normalization: Scalar = jnp.array(1.0)
-):
-    return jax.random.normal(key=key, shape=shape) / normalization
-
-
-class S6DSemigroup(Semigroup):
-    """The S6 semigroup (recurrent update) from https://arxiv.org/abs/2312.00752.
-
-    This is an S5/LRU recurrent update with a learnable timestep parameter. """
-
-    recurrent_size: int
-
-    def __init__(
-        self,
-        recurrent_size: int,
-    ):
-        self.recurrent_size = recurrent_size
-
-    @jaxtyped(typechecker=typechecker)
-    def initialize_carry(
-        self, key: Optional[Shaped[PRNGKeyArray, ""]] = None
-    ) -> S6RecurrentState:
-        # Represent a diagonal matrix as a vector
-        return (
-            jnp.ones((self.recurrent_size,)),
-            jnp.zeros((self.recurrent_size,))
-        )
-
-    @jaxtyped(typechecker=typechecker)
-    def __call__(
-        self, carry: S6RecurrentState, input: S6RecurrentState
-    ) -> S6RecurrentState:
-        A_i, bu_i = carry
-        A_j, bu_j = input
-        return A_j * A_i, A_j * bu_i + bu_j
-
-
-class S6D(GRAS):
-    """
-    The S6 SSM. We base this on the LRU, and add a trainable dt.
-
-    You might want to use this as a building block for a more complex model.
-    """
-
-    algebra: BinaryAlgebra
-    scan: Callable[
-        [
-            Callable[
-                [S6RecurrentStateWithReset, S6RecurrentStateWithReset],
-                S6RecurrentStateWithReset,
-            ],
-            S6RecurrentStateWithReset,
-            S6RecurrentStateWithReset,
-        ],
-        S6RecurrentStateWithReset,
-    ]
-    A_log: Float[Array, "Recurrent"]
-    B: nn.Linear
-    C: nn.Linear
-    dt: nn.Linear
-
-    hidden_size: int  # input and output dimensions
-    recurrent_size: int  # hidden state dimension
-
-    def __init__(self, recurrent_size, hidden_size, key):
-        keys = jax.random.split(key, 4)
-        self.recurrent_size = recurrent_size
-        self.hidden_size = hidden_size
-        unwrapped = S6DSemigroup(recurrent_size)
-        self.algebra = Resettable(unwrapped)
-        self.scan = semigroup_scan
-
-        self.A_log = jax.random.normal(keys[0], (self.recurrent_size,))
-        self.B = nn.Linear(self.hidden_size, self.recurrent_size, key=keys[1])
-        self.C = nn.Linear(self.recurrent_size, self.hidden_size, key=keys[2])
-        self.dt = nn.Sequential([
-            nn.Linear(self.hidden_size, 1, key=keys[3]),
-            nn.Lambda(jax.nn.softplus)
-        ])
-
-    @jaxtyped(typechecker=typechecker)
-    def forward_map(self, x: Input, key: Optional[Shaped[PRNGKeyArray, ""]] = None):
-        emb, start = x
-        dt = self.dt(emb)
-        A = -jnp.exp(self.A_log)
-        A_bar = jnp.exp(dt * A)
-        B = self.B(emb)
-        B_bar = dt * B
-        Bu = B_bar * emb
-        return (A_bar, Bu), start
-
-    @jaxtyped(typechecker=typechecker)
-    def backward_map(
-        self,
-        h: S6RecurrentStateWithReset,
-        x: Input,
-        key: Optional[Shaped[PRNGKeyArray, ""]] = None,
-    ) -> Float[Array, "{self.recurrent_size}"]:
-        state, reset_flag = h
-        emb, start = x
-        lambdas, lambda_x_Bu = state
-        C = self.C(emb)
-        out = C * lambda_x_Bu
-        return out
-
-    @jaxtyped(typechecker=typechecker)
-    def initialize_carry(
-        self, key: Optional[Shaped[PRNGKeyArray, ""]] = None
-    ) -> S6RecurrentStateWithReset:
-        return self.algebra.initialize_carry(key)
diff --git a/memorax/equinox/train_utils.py b/memorax/equinox/train_utils.py
index ca137f8..cff2bee 100644
--- a/memorax/equinox/train_utils.py
+++ b/memorax/equinox/train_utils.py
@@ -2,7 +2,7 @@
 It includes loss functions, accuracy metrics, and training loops.
 It also provides a straightforward way to construct multi-layer recurrent models."""
 
-from beartype.typing import Callable, Dict, Tuple
+from beartype.typing import Callable, Dict, Tuple, Any, Optional
 
 import equinox as eqx
 import jax
@@ -25,7 +25,6 @@
 from memorax.equinox.semigroups.nmax import NMax, NMaxSemigroup
 from memorax.equinox.semigroups.spherical import PSpherical, PSphericalSemigroup
 from memorax.equinox.semigroups.s6 import S6, S6Semigroup
-from memorax.equinox.semigroups.s6d import S6D, S6DSemigroup
 from memorax.equinox.semigroups.delta import DeltaNet, DeltaNetSemigroup
 from memorax.equinox.semigroups.deltap import DeltaProduct, DeltaProductSemigroup
 from memorax.equinox.semigroups.gdn import GDN, GDNSemigroup
@@ -226,7 +225,6 @@ def get_semigroups(
         "LinearRNN": LinearRNNSemigroup(recurrent_size),
         "LRU": LRUSemigroup(recurrent_size),
         "S6": S6Semigroup(recurrent_size),
-        "S6D": S6DSemigroup(recurrent_size),
         "NMax": NMaxSemigroup(recurrent_size),
         "FWP": FWPSemigroup(recurrent_size),
         "DeltaNet": DeltaNetSemigroup(recurrent_size),
@@ -244,81 +242,83 @@ def get_residual_memory_models(
     models: str = "all",
     *,
     key: jax.random.PRNGKey,
+    layer_kwargs: Optional[Dict[str, Any]] = None,
+    model_kwargs: Optional[Dict] = None,
 ) -> Dict[str, Module]:
     """Returns a dictionary of models, correponding to all semigroups and set actions.
 
     This returns a dictionary of models, each consisting of multiple recurrent cells
    with residual and DenseNet connections between them.
     """
+    layer_kwargs = layer_kwargs or {}
+    model_kwargs = model_kwargs or {}
     layers = {
         # for debug
         "MLP": lambda recurrent_size, key: MLP(
-            recurrent_size=recurrent_size, key=key
+            recurrent_size=recurrent_size, key=key, **layer_kwargs.get("MLP", {})
         ),
         # semigroups
         "NMax": lambda recurrent_size, key: NMax(
             recurrent_size=recurrent_size, key=key
         ),
         "FART": lambda recurrent_size, key: FART(
-            hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), key=key
+            hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), key=key, **layer_kwargs.get("FART", {})
         ),
         "FWP": lambda recurrent_size, key: FWP(
-            hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), key=key
+            hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), key=key, **layer_kwargs.get("FWP", {})
        ),
         "DeltaNet": lambda recurrent_size, key: DeltaNet(
-            hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), key=key
+            hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), key=key, **layer_kwargs.get("DeltaNet", {})
        ),
         "DeltaProduct": lambda recurrent_size, key: DeltaProduct(
-            hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), rank=4, key=key
+            hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), rank=4, key=key, **layer_kwargs.get("DeltaProduct", {})
        ),
         "GDN": lambda recurrent_size, key: GDN(
-            hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), key=key
+            hidden_size=recurrent_size, recurrent_size=round(recurrent_size ** 0.5), key=key, **layer_kwargs.get("GDN", {})
        ),
         "FFM": lambda recurrent_size, key: FFM(
-            hidden_size=recurrent_size, context_size=recurrent_size//4, trace_size=4, key=key
-        ),
-        "S6D": lambda recurrent_size, key: S6D(
-            hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key
+            hidden_size=recurrent_size, context_size=recurrent_size//4, trace_size=4, key=key, **layer_kwargs.get("FFM", {})
         ),
         "S6": lambda recurrent_size, key: S6(
-            hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key
+            hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key, **layer_kwargs.get("S6", {})
         ),
         "PSpherical": lambda recurrent_size, key: PSpherical(
             recurrent_size=round(recurrent_size ** 0.5),
             hidden_size=recurrent_size,
-            key=key
+            key=key,
+            **layer_kwargs.get("PSpherical", {})
         ),
         "LRU": lambda recurrent_size, key: LRU(
-            hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key
+            hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key, **layer_kwargs.get("LRU", {})
         ),
         "LinearRNN": lambda recurrent_size, key: LinearRecurrent(
-            recurrent_size=recurrent_size, key=key
+            recurrent_size=recurrent_size, key=key, **layer_kwargs.get("LinearRNN", {})
         ),
         "Stack": lambda recurrent_size, key: Stack(
-            recurrent_size=recurrent_size, stack_size=4, key=key
+            recurrent_size=recurrent_size, stack_size=4, key=key, **layer_kwargs.get("Stack", {})
         ),
         "Attention": lambda recurrent_size, key: Attention(
-            recurrent_size=recurrent_size, window_size=20, positional_embedding=None, key=key
+            recurrent_size=recurrent_size, window_size=20, positional_embedding=None, key=key, **layer_kwargs.get("Attention", {})
         ),
         "Attention-RoPE": lambda recurrent_size, key: Attention(
-            recurrent_size=recurrent_size, window_size=20, positional_embedding="rope", key=key
+            recurrent_size=recurrent_size, window_size=20, positional_embedding="rope", key=key, **layer_kwargs.get("Attention-RoPE", {})
         ),
         "Attention-ALiBi": lambda recurrent_size, key: Attention(
-            recurrent_size=recurrent_size, window_size=20, positional_embedding="alibi", key=key
+            recurrent_size=recurrent_size, window_size=20, positional_embedding="alibi", key=key, **layer_kwargs.get("Attention-ALiBi", {})
         ),
         # set actions
-        "GRU": lambda recurrent_size, key: GRU(recurrent_size=recurrent_size, key=key),
+        "GRU": lambda recurrent_size, key: GRU(recurrent_size=recurrent_size, key=key, **layer_kwargs.get("GRU", {})),
         "Elman": lambda recurrent_size, key: Elman(
-            hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key
+            hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key, **layer_kwargs.get("Elman", {})
         ),
         "ElmanReLU": lambda recurrent_size, key: Elman(
-            hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key, activation=jax.nn.relu,
+            hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key, activation=jax.nn.relu, **layer_kwargs.get("ElmanReLU", {})
         ),
         "Spherical": lambda recurrent_size, key: Spherical(
-            hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key
+            hidden_size=recurrent_size, recurrent_size=recurrent_size, key=key, **layer_kwargs.get("Spherical", {})
         ),
-        "MGU": lambda recurrent_size, key: MGU(recurrent_size=recurrent_size, key=key),
-        "LSTM": lambda recurrent_size, key: LSTM(recurrent_size=recurrent_size, key=key),
+        "MGU": lambda recurrent_size, key: MGU(recurrent_size=recurrent_size, key=key, **layer_kwargs.get("MGU", {})),
+        "LSTM": lambda recurrent_size, key: LSTM(recurrent_size=recurrent_size, key=key, **layer_kwargs.get("LSTM", {})),
     }
     if models == "all":
         return {
@@ -329,6 +329,7 @@ def get_residual_memory_models(
             output_size=output,
             num_layers=num_layers,
             key=key,
+            **model_kwargs,
         )
         for name, fn in layers.items()
     }
@@ -341,6 +342,7 @@ def get_residual_memory_models(
             output_size=output,
             num_layers=num_layers,
             key=key,
+            **model_kwargs,
         )
         for name in models
     }
\ No newline at end of file
diff --git a/memorax/linen/semigroups/s6d.py b/memorax/linen/semigroups/s6.py
similarity index 86%
rename from memorax/linen/semigroups/s6d.py
rename to memorax/linen/semigroups/s6.py
index f3e4400..a670736 100644
--- a/memorax/linen/semigroups/s6d.py
+++ b/memorax/linen/semigroups/s6.py
@@ -13,11 +13,11 @@
 from memorax.linen.gras import GRAS
 from memorax.linen.scans import semigroup_scan
 
-S6DRecurrentState = Tuple[Float[Array, "Recurrent"], Float[Array, "Recurrent"]]
-S6DRecurrentStateWithReset = Tuple[S6DRecurrentState, StartFlag]
+S6RecurrentState = Tuple[Float[Array, "Recurrent"], Float[Array, "Recurrent"]]
+S6RecurrentStateWithReset = Tuple[S6RecurrentState, StartFlag]
 
 
-class S6DSemigroup(Semigroup):
+class S6Semigroup(Semigroup):
     """The diagonal S6 semigroup (recurrent update) from https://arxiv.org/abs/2312.00752.
 
     This is a diagonal S5/LRU recurrent update with a learnable timestep parameter. """
@@ -27,7 +27,7 @@ class S6DSemigroup(Semigroup):
     @jaxtyped(typechecker=typechecker)
     def initialize_carry(
         self, key: Optional[Shaped[PRNGKeyArray, ""]] = None
-    ) -> S6DRecurrentState:
+    ) -> S6RecurrentState:
         # Represent a diagonal matrix as a vector
         return (
             jnp.ones((self.recurrent_size,)),
@@ -35,7 +35,7 @@ def initialize_carry(
         )
 
     @nn.nowrap
-    def zero_carry(self) -> S6DRecurrentState:
+    def zero_carry(self) -> S6RecurrentState:
         return (
             jnp.zeros((self.recurrent_size,)),
             jnp.zeros((self.recurrent_size)),
@@ -44,8 +44,8 @@ def zero_carry(self) -> S6DRecurrentState:
     @jaxtyped(typechecker=typechecker)
     @nn.compact
     def __call__(
-        self, carry: S6DRecurrentState, input: S6DRecurrentState
-    ) -> S6DRecurrentState:
+        self, carry: S6RecurrentState, input: S6RecurrentState
+    ) -> S6RecurrentState:
         # Ax + Bu, but A is diagonal, and we treat it as a vector
         # So we can be more efficient by writing Ax as vec(A) * x
         A_i, bu_i = carry
@@ -53,7 +53,7 @@ def __call__(
         return A_j * A_i, A_j * bu_i + bu_j
 
 
-class S6D(GRAS):
+class S6(GRAS):
     """
     The diagonal S6 SSM, an SSM with a trainable dt.
 
@@ -93,7 +93,7 @@ def forward_map(self, x: Input, key: Optional[Shaped[PRNGKeyArray, ""]] = None):
     @jaxtyped(typechecker=typechecker)
     def backward_map(
         self,
-        h: S6DRecurrentStateWithReset,
+        h: S6RecurrentStateWithReset,
         x: Input,
         key: Optional[Shaped[PRNGKeyArray, ""]] = None,
     ) -> Float[Array, "{self.recurrent_size}"]:
@@ -107,16 +107,16 @@ def backward_map(
     @jaxtyped(typechecker=typechecker)
     def initialize_carry(
         self, key: Optional[Shaped[PRNGKeyArray, ""]] = None
-    ) -> S6DRecurrentStateWithReset:
+    ) -> S6RecurrentStateWithReset:
         return self.algebra.initialize_carry(key)
 
     @nn.nowrap
-    def zero_carry(self) -> S6DRecurrentStateWithReset:
+    def zero_carry(self) -> S6RecurrentStateWithReset:
         return self.algebra.zero_carry()
 
     @staticmethod
     def default_algebra(**kwargs):
-        return Resettable(S6DSemigroup(**kwargs))
+        return Resettable(S6Semigroup(**kwargs))
 
     @staticmethod
     def default_scan():
diff --git a/memorax/linen/train_utils.py b/memorax/linen/train_utils.py
index 083e60f..0ddb93a 100644
--- a/memorax/linen/train_utils.py
+++ b/memorax/linen/train_utils.py
@@ -2,7 +2,8 @@
 It includes loss functions, accuracy metrics, and training loops.
 It also provides a straightforward way to construct multi-layer recurrent models."""
 
-from beartype.typing import Callable, Dict, Tuple
+from typing import Any
+from beartype.typing import Callable, Dict, Tuple, Optional, Any
 
 import jax
 import jax.numpy as jnp
@@ -14,7 +15,7 @@
 from memorax.linen.models.residual import ResidualModel
 from memorax.linen.semigroups.fart import FARTSemigroup, FART
 from memorax.linen.semigroups.lru import LRUSemigroup, LRU
-from memorax.linen.semigroups.s6d import S6DSemigroup, S6D
+from memorax.linen.semigroups.s6 import S6Semigroup, S6
 
 
 def add_batch_dim(h, batch_size: int, axis: int = 0) -> Shaped[Array, "Batch ..."]:
@@ -85,45 +86,68 @@ def get_semigroups(
     return {
         "FART": FARTSemigroup(recurrent_size),
         "LRU": LRUSemigroup(recurrent_size),
-        "S6D": S6DSemigroup(recurrent_size),
+        "S6": S6Semigroup(recurrent_size),
     }
 
 
 def get_residual_memory_models(
     hidden: int,
     output: int,
     num_layers: int = 2,
+    models: str = "all",
+    layer_kwargs: Optional[Dict[str, Any]] = None,
+    model_kwargs: Optional[Dict] = None
 ) -> Dict:
+    """Constructs a trunk of stacked memory cells."""
+    layer_kwargs = layer_kwargs or {}
+    model_kwargs = model_kwargs or {}
     layers = {
         "FART": lambda recurrent_size: FART(
             algebra=FART.default_algebra(recurrent_size=round(recurrent_size**0.5)),
             scan=FART.default_scan(),
             hidden_size=recurrent_size,
             recurrent_size=round(recurrent_size**0.5),
+            **layer_kwargs.get("FART", {})
         ),
         "LRU": lambda recurrent_size: LRU(
             algebra=LRU.default_algebra(recurrent_size=recurrent_size),
             scan=LRU.default_scan(),
             hidden_size=recurrent_size,
             recurrent_size=recurrent_size,
+            **layer_kwargs.get("LRU", {})
         ),
-        "S6D": lambda recurrent_size: S6D(
-            algebra=S6D.default_algebra(recurrent_size=recurrent_size),
-            scan=S6D.default_scan(),
+        "S6": lambda recurrent_size: S6(
+            algebra=S6.default_algebra(recurrent_size=recurrent_size),
+            scan=S6.default_scan(),
             hidden_size=recurrent_size,
             recurrent_size=recurrent_size,
+            **layer_kwargs.get("S6", {})
         ),
         "GRU": lambda recurrent_size: GRU(
             algebra=GRU.default_algebra(recurrent_size=recurrent_size),
             scan=GRU.default_scan(),
             recurrent_size=recurrent_size,
+            **layer_kwargs.get("GRU", {})
         ),
     }
-    return {
-        name: ResidualModel(
-            make_layer_fn=fn,
-            recurrent_size=hidden,
-            output_size=output,
-            num_layers=num_layers,
-        )
-        for name, fn in layers.items()
-    }
+    if models == "all":
+        return {
+            name: ResidualModel(
+                make_layer_fn=fn,
+                recurrent_size=hidden,
+                output_size=output,
+                num_layers=num_layers,
+                **model_kwargs,
+            )
+            for name, fn in layers.items()
+        }
+    else:
+        return {
+            name: ResidualModel(
+                make_layer_fn=layers[name],
+                recurrent_size=hidden,
+                output_size=output,
+                num_layers=num_layers,
+                **model_kwargs,
+            )
+            for name in models
+        }
\ No newline at end of file
diff --git a/tests/test_initial_input_linen.py b/tests/test_initial_input_linen.py
index f5f75ad..0a8ed0b 100644
--- a/tests/test_initial_input_linen.py
+++ b/tests/test_initial_input_linen.py
@@ -11,7 +11,7 @@ def get_desired_accuracies():
 
     return {
         "LRU": 0.999,
-        "S6D": 0.999,
+        "S6": 0.999,
         "FART": 0.999,
         "GRU": 0.999,
     }
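Usage note (not part of the patch): a minimal sketch of how the new layer_kwargs / model_kwargs arguments added to get_residual_memory_models above are meant to be threaded through. The model names, the per-name kwargs lookup, and the forwarding into ResidualModel come from the diff itself; the import path, the sizes, the PRNG seed, and passing a list for models (the annotation says str, but the non-"all" branch simply iterates it) are assumptions for illustration.

import jax

from memorax.equinox.train_utils import get_residual_memory_models

# Build only a subset of trunks; per-model overrides are keyed by model name and
# splatted into each cell, while model_kwargs is splatted into every ResidualModel.
models = get_residual_memory_models(
    hidden=64,                                              # illustrative recurrent size
    output=4,                                               # illustrative output size
    num_layers=2,
    models=["Elman", "LRU"],                                # subset instead of "all"
    key=jax.random.PRNGKey(0),
    layer_kwargs={"Elman": {"activation": jax.nn.tanh}},    # Elman accepts activation per the diff
    model_kwargs={},                                        # left empty here
)
elman_trunk = models["Elman"]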