make gemma FFW LoRA friendly
PiperOrigin-RevId: 721490347
Flax Team committed Jan 30, 2025
1 parent 881685c commit 7f1567c
Showing 5 changed files with 173 additions and 20 deletions.
21 changes: 19 additions & 2 deletions examples/gemma/helpers.py
@@ -16,8 +16,9 @@
 
 from __future__ import annotations
 
-from typing import TypeVar
+from collections.abc import Callable
+from typing import Any, TypeVar
 
 import flax
 from flax import nnx
 from flax.typing import VariableDict # pylint: disable=g-importing-member,g-multiple-import
@@ -43,6 +44,12 @@ def module_from_linen_variables(
     map_key_fn: None | (
         Callable[[tuple[str, ...]], tuple[str | int, ...]]
     ) = None,
+    assign_val_fn: None | (
+        Callable[
+            [dict[tuple[str, ...], Any], tuple[str | int, ...], VariableDict],
+            dict[tuple[str, ...], Any],
+        ]
+    ) = None,
 ) -> M:
   """Returns an `nnx.Module` initialized with the `variables` of a linen module.
@@ -60,6 +67,16 @@ def module_from_linen_variables(
     def map_key_fn(path: tuple[str, ...]) -> tuple[str | int, ...]:
       return path[1:] if 'params' in variables else path
 
+  if assign_val_fn is None:
+
+    def assign_val_fn(
+        state: dict[tuple[str, ...], Any],
+        mapped_path: tuple[str | int, ...],
+        val: Any,
+    ) -> dict[tuple[str, ...], Any]:
+      state[mapped_path].value = val
+      return state
+
   mdl: M = nnx.eval_shape(module_factory)
   graph_def, state = nnx.split(mdl)
   state = dict(state.flat_state())
@@ -70,7 +87,7 @@ def map_key_fn(path: tuple[str, ...]) -> tuple[str | int, ...]:
           f"'{mdl.__class__.__name__}.{_flatten_path(mapped_path)}' doesn't "
           f' exist (original path={path}).'
       )
-    state[mapped_path].value = val
+    state = assign_val_fn(state, mapped_path, val)
   state = nnx.State.from_flat_path(state)
 
   return nnx.merge(graph_def, state)
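The new `assign_val_fn` hook lets a caller decide how each checkpoint array is written into the flat nnx state, instead of the default one-to-one assignment. Below is a minimal sketch (not part of this commit) that uses the hook to cast parameters to bfloat16 while loading; the `cast_to_bf16` helper and the tiny `nnx.Linear` target module are illustrative only, and `helpers` refers to examples/gemma/helpers.py.

import jax.numpy as jnp
from flax import nnx

import helpers  # examples/gemma/helpers.py


def cast_to_bf16(state, mapped_path, val):
  # Same contract as the default assign_val_fn: write `val` into the flat nnx
  # state at `mapped_path` and return the (mutated) state dict.
  state[mapped_path].value = jnp.asarray(val, dtype=jnp.bfloat16)
  return state


linen_variables = {
    'params': {
        'kernel': jnp.ones((2, 3), dtype=jnp.float32),
        'bias': jnp.zeros((3,), dtype=jnp.float32),
    }
}
mdl = helpers.module_from_linen_variables(
    module_factory=lambda: nnx.Linear(2, 3, rngs=nnx.Rngs(params=0)),
    variables=linen_variables,
    assign_val_fn=cast_to_bf16,
)
assert mdl.kernel.value.dtype == jnp.bfloat16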
35 changes: 22 additions & 13 deletions examples/gemma/modules.py
@@ -243,27 +243,36 @@ def __init__(
       *,
       rngs: nnx.Rngs,
   ):
-    self.gating_einsum = nnx.Param(
-        nn.initializers.zeros_init()(
-            rngs.params(),
-            ((2, features, hidden_dim)),
-        )
+    self.gate_proj = nnx.Linear(
+        in_features=features,
+        out_features=hidden_dim,
+        use_bias=False,
+        rngs=rngs,
+        kernel_init=nn.initializers.zeros_init(),
     )
-    self.linear = nnx.Param(
-        nn.initializers.zeros_init()(
-            rngs.params(),
-            (hidden_dim, features),
-        )
+    self.up_proj = nnx.Linear(
+        in_features=features,
+        out_features=hidden_dim,
+        use_bias=False,
+        rngs=rngs,
+        kernel_init=nn.initializers.zeros_init(),
     )
+    self.down_proj = nnx.Linear(
+        in_features=hidden_dim,
+        out_features=features,
+        use_bias=False,
+        rngs=rngs,
+        kernel_init=nn.initializers.zeros_init(),
+    )
 
   def __call__(self, x: ArrayLike) -> Array:
-    ff_gate = jnp.dot(x, self.gating_einsum.value[0])
+    ff_gate = self.gate_proj(x)
     gate_value = nnx.gelu(ff_gate)
 
-    ff1 = jnp.dot(x, self.gating_einsum.value[1])
+    ff1 = self.up_proj(x)
     activations = gate_value * ff1
 
-    outputs = jnp.dot(activations, self.linear.value)
+    outputs = self.down_proj(activations)
     return outputs
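With the fused einsum parameters replaced by three standard `nnx.Linear` projections, each projection can be wrapped or swapped independently, which is what makes the feed-forward block LoRA friendly. A minimal sketch follows, assuming `modules` (examples/gemma/modules.py) is importable as in modules_test.py and that `FeedForward` is constructed as in the test; the `LoRAAdapter` class and the `rank` value are illustrative, not flax API (flax.nnx also ships its own LoRA wrappers).

import jax
import jax.numpy as jnp
from flax import nnx

import modules  # examples/gemma/modules.py


class LoRAAdapter(nnx.Module):
  """Adds a trainable low-rank update on top of a (frozen) nnx.Linear."""

  def __init__(self, base: nnx.Linear, rank: int, *, rngs: nnx.Rngs):
    in_features, out_features = base.kernel.value.shape
    self.base = base
    self.lora_a = nnx.Param(
        jax.nn.initializers.normal(stddev=0.01)(rngs.params(), (in_features, rank))
    )
    self.lora_b = nnx.Param(jnp.zeros((rank, out_features)))

  def __call__(self, x):
    return self.base(x) + x @ self.lora_a.value @ self.lora_b.value


ffw = modules.FeedForward(features=8, hidden_dim=32, rngs=nnx.Rngs(params=0))
# Wrap a single projection; up_proj and down_proj could be wrapped the same way.
ffw.gate_proj = LoRAAdapter(ffw.gate_proj, rank=4, rngs=nnx.Rngs(params=1))
y = ffw(jnp.ones((1, 8)))  # down_proj(gelu(gate_proj(x)) * up_proj(x)), with LoRA on the gate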


5 changes: 3 additions & 2 deletions examples/gemma/modules_test.py
@@ -223,8 +223,9 @@ def test_ffw(
         hidden_dim=hidden_dim,
         rngs=nnx.Rngs(params=0),
     )
-    ffw.gating_einsum.value = jnp.ones((2, features, hidden_dim))
-    ffw.linear.value = jnp.ones((hidden_dim, features))
+    ffw.gate_proj.kernel.value = jnp.ones((features, hidden_dim))
+    ffw.up_proj.kernel.value = jnp.ones((features, hidden_dim))
+    ffw.down_proj.kernel.value = jnp.ones((hidden_dim, features))
 
     with jax.default_matmul_precision('float32'):
       outputs = ffw(inputs)
30 changes: 27 additions & 3 deletions examples/gemma/transformer.py
@@ -16,8 +16,9 @@
 
 from __future__ import annotations
 
-import dataclasses
 from collections.abc import Iterable
+import dataclasses
+from typing import Any
 
 from flax import nnx
 import helpers
@@ -180,22 +181,45 @@ def _map_linen_var_names(key: tuple[str, ...]) -> tuple[str | int, ...]:
       assert not prefix, prefix
       new_key.append('layers')
       new_key.append(int(suffix))
+    elif k == 'gating_einsum':
+      new_key.append('gate_proj')
+      new_key.append('kernel')
+    elif k == 'linear':
+      new_key.append('down_proj')
+      new_key.append('kernel')
     else:
       new_key.append(k)
 
   return tuple(new_key)
 
 
+def _assign_linen_params_to_nnx_state(
+    state: dict[tuple[str, ...], Any],
+    mapped_path: tuple[str | int, ...],
+    val: Any,
+) -> dict[tuple[str, ...], Any]:
+  if 'gate_proj' in mapped_path:
+    state[mapped_path].value = val[0]
+    state[mapped_path[:-2] + ('up_proj', 'kernel')].value = val[1]
+  else:
+    state[mapped_path].value = val
+  return state
+
+
 class Transformer(nnx.Module):
   """Gemma transformer."""
 
   @classmethod
-  def from_params(cls, params: params_lib.Params) -> Transformer:
-    config = TransformerConfig.from_params(params)
+  def from_params(
+      cls, params: params_lib.Params, config: None | TransformerConfig = None
+  ) -> Transformer:
+    if config is None:
+      config = TransformerConfig.from_params(params)
     return helpers.module_from_linen_variables(
         module_factory=lambda: cls(config, rngs=nnx.Rngs(params=0)),
         variables=params['transformer'],
         map_key_fn=_map_linen_var_names,
+        assign_val_fn=_assign_linen_params_to_nnx_state,
     )
 
   def __init__(self, config: TransformerConfig, *, rngs: nnx.Rngs):
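Together, `_map_linen_var_names` and `_assign_linen_params_to_nnx_state` route the checkpoint's fused `gating_einsum` array of shape (2, embed_dim, hidden_dim) into the two separate kernels of the new FFW block. A small sketch of that bookkeeping follows, assuming the module is importable as `transformer_lib` (as in transformer_test.py); the `SimpleNamespace` entries are only stand-ins for the real flat nnx state entries, and the two underscore-prefixed helpers are private functions used here purely for illustration.

import types

import jax.numpy as jnp

import transformer as transformer_lib  # examples/gemma/transformer.py

embed_dim, hidden_dim = 2, 12
mapped = transformer_lib._map_linen_var_names(('layer_0', 'mlp', 'gating_einsum'))
# mapped == ('layers', 0, 'mlp', 'gate_proj', 'kernel')

# Stand-ins for the flat nnx state that module_from_linen_variables builds.
state = {
    ('layers', 0, 'mlp', 'gate_proj', 'kernel'): types.SimpleNamespace(value=None),
    ('layers', 0, 'mlp', 'up_proj', 'kernel'): types.SimpleNamespace(value=None),
}
fused = jnp.ones((2, embed_dim, hidden_dim))  # checkpoint stacks gate and up projections
state = transformer_lib._assign_linen_params_to_nnx_state(state, mapped, fused)
assert state[('layers', 0, 'mlp', 'gate_proj', 'kernel')].value.shape == (embed_dim, hidden_dim)
assert state[('layers', 0, 'mlp', 'up_proj', 'kernel')].value.shape == (embed_dim, hidden_dim)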
102 changes: 102 additions & 0 deletions examples/gemma/transformer_test.py
@@ -14,6 +14,7 @@
 # ============================================================================
 """Tests for the Gemma transformer."""
 
+from collections import defaultdict
 from absl.testing import absltest
 from absl.testing import parameterized
 from flax import nnx
@@ -23,6 +24,64 @@
import numpy as np


def create_fake_params(config: transformer_lib.TransformerConfig):
def nested_defaultdict():
return defaultdict(nested_defaultdict)

res = nested_defaultdict()
res['transformer'] = nested_defaultdict()
params = res['transformer']
# 1. embedding params
params['embedder']['input_embedding'] = jnp.ones(
(config.num_embed, config.embed_dim)
)
# 2. final norm params
params['final_norm'] = {'scale': jnp.ones((config.embed_dim,))}

# 3. attention block params
for layer_idx in range(config.num_layers):
params[f'layer_{layer_idx}']['attn']['attn_vec_einsum']['w'] = jnp.ones(
(config.num_heads, config.head_dim, config.embed_dim)
)
if config.num_heads == config.num_kv_heads:
params[f'layer_{layer_idx}']['attn']['qkv_einsum']['w'] = jnp.ones(
(3, config.num_heads, config.embed_dim, config.head_dim)
)
else:
params[f'layer_{layer_idx}']['attn']['q_einsum']['w'] = jnp.ones(
(config.num_heads, config.embed_dim, config.head_dim)
)
params[f'layer_{layer_idx}']['attn']['kv_einsum']['w'] = jnp.ones(
(config.num_kv_heads, config.embed_dim, config.head_dim)
)

# 4. feedforward block params
params[f'layer_{layer_idx}']['mlp']['gating_einsum'] = jnp.ones(
(2, config.embed_dim, config.hidden_dim)
)
params[f'layer_{layer_idx}']['mlp']['linear'] = jnp.ones(
(config.hidden_dim, config.embed_dim)
)

# 5. layer norm params
params[f'layer_{layer_idx}']['pre_attention_norm']['scale'] = jnp.ones((
config.embed_dim,
))
params[f'layer_{layer_idx}']['pre_ffw_norm']['scale'] = jnp.ones((
config.embed_dim,
))

if config.use_post_attn_norm:
params[f'layer_{layer_idx}']['post_attn_norm']['scale'] = jnp.ones((
config.embed_dim,
))
if config.use_post_ffw_norm:
params[f'layer_{layer_idx}']['post_ffw_norm']['scale'] = jnp.ones((
config.embed_dim,
))
return res


class TransformerTest(parameterized.TestCase):

@parameterized.parameters(
@@ -290,6 +349,49 @@ def test_attention_types(
)
self.assertTrue(cache)

@parameterized.parameters(
dict(
config=transformer_lib.TransformerConfig(
num_layers=2,
num_embed=4,
embed_dim=2,
hidden_dim=12,
num_heads=3,
head_dim=4,
num_kv_heads=3,
final_logit_softcap=None,
attention_types=[modules.AttentionType.GLOBAL] * 2,
use_post_attn_norm=False,
use_post_ffw_norm=False,
),
),
dict(
config=transformer_lib.TransformerConfig(
num_layers=2,
num_embed=4,
embed_dim=2,
hidden_dim=12,
num_heads=3,
head_dim=4,
num_kv_heads=3,
final_logit_softcap=None,
attention_types=[modules.AttentionType.GLOBAL] * 2,
use_post_attn_norm=True,
use_post_ffw_norm=True,
),
),
)
def test_load_from_params(self, config):
params = create_fake_params(config)
transformer = transformer_lib.Transformer.from_params(params, config)
logits, _ = transformer(
last_tokens=jnp.tile(jnp.arange(3), (2, 1)),
positions=jnp.tile(jnp.arange(3), (2, 1)),
cache=None,
attention_mask=jnp.ones((2, 1, 3), dtype=jnp.bool),
)
self.assertEqual(logits.shape, (2, 3, 4))


if __name__ == '__main__':
absltest.main()
