diff --git a/examples/gemma/helpers.py b/examples/gemma/helpers.py
index b9c4195f4..7743563c0 100644
--- a/examples/gemma/helpers.py
+++ b/examples/gemma/helpers.py
@@ -16,8 +16,9 @@
 from __future__ import annotations
 
-from typing import TypeVar
 from collections.abc import Callable
+from typing import Any, TypeVar
+
 import flax
 from flax import nnx
 from flax.typing import VariableDict  # pylint: disable=g-importing-member,g-multiple-import
@@ -43,6 +44,12 @@ def module_from_linen_variables(
     map_key_fn: None | (
         Callable[[tuple[str, ...]], tuple[str | int, ...]]
     ) = None,
+    assign_val_fn: None | (
+        Callable[
+            [dict[tuple[str, ...], Any], tuple[str | int, ...], VariableDict],
+            dict[tuple[str, ...], Any],
+        ]
+    ) = None,
 ) -> M:
   """Returns an `nnx.Module` initialized with the `variables` of a linen module.
@@ -60,6 +67,16 @@ def map_key_fn(path: tuple[str, ...]) -> tuple[str | int, ...]:
       return path[1:] if 'params' in variables else path
 
+  if assign_val_fn is None:
+
+    def assign_val_fn(
+        state: dict[tuple[str, ...], Any],
+        mapped_path: tuple[str | int, ...],
+        val: Any,
+    ) -> dict[tuple[str, ...], Any]:
+      state[mapped_path].value = val
+      return state
+
   mdl: M = nnx.eval_shape(module_factory)
   graph_def, state = nnx.split(mdl)
   state = dict(state.flat_state())
@@ -70,7 +87,7 @@ def map_key_fn(path: tuple[str, ...]) -> tuple[str | int, ...]:
           f"'{mdl.__class__.__name__}.{_flatten_path(mapped_path)}' doesn't "
           f' exist (original path={path}).'
       )
-    state[mapped_path].value = val
+    state = assign_val_fn(state, mapped_path, val)
   state = nnx.State.from_flat_path(state)
   return nnx.merge(graph_def, state)
diff --git a/examples/gemma/modules.py b/examples/gemma/modules.py
index 38708803c..ae0b08ba9 100644
--- a/examples/gemma/modules.py
+++ b/examples/gemma/modules.py
@@ -243,27 +243,36 @@ def __init__(
       *,
       rngs: nnx.Rngs,
   ):
-    self.gating_einsum = nnx.Param(
-        nn.initializers.zeros_init()(
-            rngs.params(),
-            ((2, features, hidden_dim)),
-        )
+    self.gate_proj = nnx.Linear(
+        in_features=features,
+        out_features=hidden_dim,
+        use_bias=False,
+        rngs=rngs,
+        kernel_init=nn.initializers.zeros_init(),
     )
-    self.linear = nnx.Param(
-        nn.initializers.zeros_init()(
-            rngs.params(),
-            (hidden_dim, features),
-        )
+    self.up_proj = nnx.Linear(
+        in_features=features,
+        out_features=hidden_dim,
+        use_bias=False,
+        rngs=rngs,
+        kernel_init=nn.initializers.zeros_init(),
+    )
+    self.down_proj = nnx.Linear(
+        in_features=hidden_dim,
+        out_features=features,
+        use_bias=False,
+        rngs=rngs,
+        kernel_init=nn.initializers.zeros_init(),
     )
 
   def __call__(self, x: ArrayLike) -> Array:
-    ff_gate = jnp.dot(x, self.gating_einsum.value[0])
+    ff_gate = self.gate_proj(x)
     gate_value = nnx.gelu(ff_gate)
 
-    ff1 = jnp.dot(x, self.gating_einsum.value[1])
+    ff1 = self.up_proj(x)
     activations = gate_value * ff1
 
-    outputs = jnp.dot(activations, self.linear.value)
+    outputs = self.down_proj(activations)
     return outputs
diff --git a/examples/gemma/modules_test.py b/examples/gemma/modules_test.py
index b585cb3c5..7439585cc 100644
--- a/examples/gemma/modules_test.py
+++ b/examples/gemma/modules_test.py
@@ -223,8 +223,9 @@ def test_ffw(
         hidden_dim=hidden_dim,
         rngs=nnx.Rngs(params=0),
     )
-    ffw.gating_einsum.value = jnp.ones((2, features, hidden_dim))
-    ffw.linear.value = jnp.ones((hidden_dim, features))
+    ffw.gate_proj.kernel.value = jnp.ones((features, hidden_dim))
+    ffw.up_proj.kernel.value = jnp.ones((features, hidden_dim))
+    ffw.down_proj.kernel.value = jnp.ones((hidden_dim, features))
     with jax.default_matmul_precision('float32'):
       outputs = ffw(inputs)
diff --git a/examples/gemma/transformer.py b/examples/gemma/transformer.py
index b4543316e..f20da6df8 100644
--- a/examples/gemma/transformer.py
+++ b/examples/gemma/transformer.py
@@ -16,8 +16,9 @@
 from __future__ import annotations
 
-import dataclasses
 from collections.abc import Iterable
+import dataclasses
+from typing import Any
 
 from flax import nnx
 import helpers
@@ -180,22 +181,45 @@ def _map_linen_var_names(key: tuple[str, ...]) -> tuple[str | int, ...]:
       assert not prefix, prefix
       new_key.append('layers')
      new_key.append(int(suffix))
+    elif k == 'gating_einsum':
+      new_key.append('gate_proj')
+      new_key.append('kernel')
+    elif k == 'linear':
+      new_key.append('down_proj')
+      new_key.append('kernel')
     else:
       new_key.append(k)
 
   return tuple(new_key)
 
 
+def _assign_linen_params_to_nnx_state(
+    state: dict[tuple[str, ...], Any],
+    mapped_path: tuple[str | int, ...],
+    val: Any,
+) -> dict[tuple[str, ...], Any]:
+  if 'gate_proj' in mapped_path:
+    state[mapped_path].value = val[0]
+    state[mapped_path[:-2] + ('up_proj', 'kernel')].value = val[1]
+  else:
+    state[mapped_path].value = val
+  return state
+
+
 class Transformer(nnx.Module):
   """Gemma transformer."""
 
   @classmethod
-  def from_params(cls, params: params_lib.Params) -> Transformer:
-    config = TransformerConfig.from_params(params)
+  def from_params(
+      cls, params: params_lib.Params, config: None | TransformerConfig = None
+  ) -> Transformer:
+    if config is None:
+      config = TransformerConfig.from_params(params)
     return helpers.module_from_linen_variables(
         module_factory=lambda: cls(config, rngs=nnx.Rngs(params=0)),
         variables=params['transformer'],
         map_key_fn=_map_linen_var_names,
+        assign_val_fn=_assign_linen_params_to_nnx_state,
     )
 
   def __init__(self, config: TransformerConfig, *, rngs: nnx.Rngs):
diff --git a/examples/gemma/transformer_test.py b/examples/gemma/transformer_test.py
index efb0cf222..02914e753 100644
--- a/examples/gemma/transformer_test.py
+++ b/examples/gemma/transformer_test.py
@@ -14,6 +14,7 @@
 # ============================================================================
 """Tests for the Gemma transformer."""
 
+from collections import defaultdict
 from absl.testing import absltest
 from absl.testing import parameterized
 from flax import nnx
@@ -23,6 +24,64 @@
 import numpy as np
 
 
+def create_fake_params(config: transformer_lib.TransformerConfig):
+  def nested_defaultdict():
+    return defaultdict(nested_defaultdict)
+
+  res = nested_defaultdict()
+  res['transformer'] = nested_defaultdict()
+  params = res['transformer']
+  # 1. embedding params
+  params['embedder']['input_embedding'] = jnp.ones(
+      (config.num_embed, config.embed_dim)
+  )
+  # 2. final norm params
+  params['final_norm'] = {'scale': jnp.ones((config.embed_dim,))}
+
+  # 3. attention block params
+  for layer_idx in range(config.num_layers):
+    params[f'layer_{layer_idx}']['attn']['attn_vec_einsum']['w'] = jnp.ones(
+        (config.num_heads, config.head_dim, config.embed_dim)
+    )
+    if config.num_heads == config.num_kv_heads:
+      params[f'layer_{layer_idx}']['attn']['qkv_einsum']['w'] = jnp.ones(
+          (3, config.num_heads, config.embed_dim, config.head_dim)
+      )
+    else:
+      params[f'layer_{layer_idx}']['attn']['q_einsum']['w'] = jnp.ones(
+          (config.num_heads, config.embed_dim, config.head_dim)
+      )
+      params[f'layer_{layer_idx}']['attn']['kv_einsum']['w'] = jnp.ones(
+          (config.num_kv_heads, config.embed_dim, config.head_dim)
+      )
+
+    # 4. feedforward block params
+    params[f'layer_{layer_idx}']['mlp']['gating_einsum'] = jnp.ones(
+        (2, config.embed_dim, config.hidden_dim)
+    )
+    params[f'layer_{layer_idx}']['mlp']['linear'] = jnp.ones(
+        (config.hidden_dim, config.embed_dim)
+    )
+
+    # 5. layer norm params
+    params[f'layer_{layer_idx}']['pre_attention_norm']['scale'] = jnp.ones((
+        config.embed_dim,
+    ))
+    params[f'layer_{layer_idx}']['pre_ffw_norm']['scale'] = jnp.ones((
+        config.embed_dim,
+    ))
+
+    if config.use_post_attn_norm:
+      params[f'layer_{layer_idx}']['post_attn_norm']['scale'] = jnp.ones((
+          config.embed_dim,
+      ))
+    if config.use_post_ffw_norm:
+      params[f'layer_{layer_idx}']['post_ffw_norm']['scale'] = jnp.ones((
+          config.embed_dim,
+      ))
+  return res
+
+
 class TransformerTest(parameterized.TestCase):
 
   @parameterized.parameters(
@@ -290,6 +349,49 @@ def test_attention_types(
     )
     self.assertTrue(cache)
 
+  @parameterized.parameters(
+      dict(
+          config=transformer_lib.TransformerConfig(
+              num_layers=2,
+              num_embed=4,
+              embed_dim=2,
+              hidden_dim=12,
+              num_heads=3,
+              head_dim=4,
+              num_kv_heads=3,
+              final_logit_softcap=None,
+              attention_types=[modules.AttentionType.GLOBAL] * 2,
+              use_post_attn_norm=False,
+              use_post_ffw_norm=False,
+          ),
+      ),
+      dict(
+          config=transformer_lib.TransformerConfig(
+              num_layers=2,
+              num_embed=4,
+              embed_dim=2,
+              hidden_dim=12,
+              num_heads=3,
+              head_dim=4,
+              num_kv_heads=3,
+              final_logit_softcap=None,
+              attention_types=[modules.AttentionType.GLOBAL] * 2,
+              use_post_attn_norm=True,
+              use_post_ffw_norm=True,
+          ),
+      ),
+  )
+  def test_load_from_params(self, config):
+    params = create_fake_params(config)
+    transformer = transformer_lib.Transformer.from_params(params, config)
+    logits, _ = transformer(
+        last_tokens=jnp.tile(jnp.arange(3), (2, 1)),
+        positions=jnp.tile(jnp.arange(3), (2, 1)),
+        cache=None,
+        attention_mask=jnp.ones((2, 1, 3), dtype=jnp.bool),
+    )
+    self.assertEqual(logits.shape, (2, 3, 4))
+
 
 if __name__ == '__main__':
   absltest.main()
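Note on the FeedForward refactor above: when loading a linen checkpoint, the fused `gating_einsum` weight of shape `(2, features, hidden_dim)` is split by `_assign_linen_params_to_nnx_state` into the `gate_proj` and `up_proj` kernels, and `linear` maps onto `down_proj`. The sketch below is illustrative only (not part of the patch); it assumes `examples/gemma` is on the import path so the edited file imports as `modules`, and that the edited class is `modules.FeedForward` as in the Gemma example. It checks that the new projection layers reproduce the old fused computation once their kernels are populated the same way the assign hook does.

# Illustrative sketch only -- not part of the patch.
import jax
import jax.numpy as jnp
from flax import nnx

import modules  # the edited examples/gemma/modules.py (assumed importable)

features, hidden_dim = 4, 8
ffw = modules.FeedForward(
    features=features, hidden_dim=hidden_dim, rngs=nnx.Rngs(params=0)
)

# Fake fused linen checkpoint weights: 'gating_einsum' has shape
# (2, features, hidden_dim) and 'linear' has shape (hidden_dim, features).
gating = jax.random.normal(jax.random.key(0), (2, features, hidden_dim))
linear = jax.random.normal(jax.random.key(1), (hidden_dim, features))

# Mirror _assign_linen_params_to_nnx_state: split the fused weight.
ffw.gate_proj.kernel.value = gating[0]
ffw.up_proj.kernel.value = gating[1]
ffw.down_proj.kernel.value = linear

x = jax.random.normal(jax.random.key(2), (2, features))
# The old formulation, written out against the fused weights.
expected = jnp.dot(nnx.gelu(jnp.dot(x, gating[0])) * jnp.dot(x, gating[1]), linear)
assert jnp.allclose(ffw(x), expected, atol=1e-5)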