Skip to content

Commit

Permalink
fix LoRA initialization error in nnx layer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 719381162
  • Loading branch information
Flax Team committed Jan 24, 2025
1 parent d28f03f commit 5d4fdd2
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 41 deletions.
86 changes: 47 additions & 39 deletions flax/nnx/nn/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,27 @@

import typing as tp

import jax
import jax.numpy as jnp

from flax.nnx import rnglib, variablelib
from flax.nnx.module import Module
from flax.nnx.nn import initializers
from flax.nnx.nn.linear import Linear
from flax.typing import Dtype, Initializer
import jax
import jax.numpy as jnp

Array = jax.Array
Axis = int
Size = int
A = tp.TypeVar('A')

default_kernel_init = initializers.lecun_normal()
default_a_initializer = initializers.he_uniform()
default_b_initializer = initializers.zeros


class LoRAParam(variablelib.Param[A]):
pass



class LoRA(Module):
"""A standalone LoRA layer.
Expand Down Expand Up @@ -70,22 +69,26 @@ class LoRA(Module):
param_dtype: the dtype passed to parameter initializers (default: float32).
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
kernel_init: initializer function for the weight matrices.
a_initializer: initializer function for the fan-in matrices. Default to
`he_uniform`.
b_initializer: initializer function for the fan-out matrices. Default to
`zero initializer`.
lora_param_type: the type of the LoRA params.
"""

def __init__(
self,
in_features: int,
lora_rank: int,
out_features: int,
*,
base_module: tp.Optional[Module] = None,
dtype: tp.Optional[Dtype] = None,
param_dtype: Dtype = jnp.float32,
kernel_init: Initializer = default_kernel_init,
lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
rngs: rnglib.Rngs,
self,
in_features: int,
lora_rank: int,
out_features: int,
*,
base_module: tp.Optional[Module] = None,
dtype: tp.Optional[Dtype] = None,
param_dtype: Dtype = jnp.float32,
a_initializer: Initializer = default_a_initializer,
b_initializer: Initializer = default_b_initializer,
lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
rngs: rnglib.Rngs,
):
self.in_features = in_features
self.out_features = out_features
Expand All @@ -95,10 +98,10 @@ def __init__(
self.base_module = base_module

self.lora_a = lora_param_type(
kernel_init(rngs.params(), (in_features, lora_rank), param_dtype)
a_initializer(rngs.params(), (in_features, lora_rank), param_dtype)
)
self.lora_b = lora_param_type(
kernel_init(rngs.params(), (lora_rank, out_features), param_dtype)
b_initializer(rngs.params(), (lora_rank, out_features), param_dtype)
)

def __call__(self, x: jax.Array):
Expand Down Expand Up @@ -142,33 +145,38 @@ class LoRALinear(Linear):
param_dtype: the dtype passed to parameter initializers (default: float32).
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
kernel_init: initializer function for the weight matrices.
a_initializer: initializer function for the fan-in matrices. Default to
`he_uniform`.
b_initializer: initializer function for the fan-out matrices. Default to
`zero initializer`.
lora_param_type: the type of the LoRA params.
"""

def __init__(
self,
in_features: int,
out_features: int,
*,
lora_rank: int,
lora_dtype: tp.Optional[Dtype] = None,
lora_param_dtype: Dtype = jnp.float32,
lora_kernel_init: Initializer = default_kernel_init,
lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
rngs: rnglib.Rngs,
**kwargs,
self,
in_features: int,
out_features: int,
*,
lora_rank: int,
lora_dtype: tp.Optional[Dtype] = None,
lora_param_dtype: Dtype = jnp.float32,
a_initializer: Initializer = default_a_initializer,
b_initializer: Initializer = default_b_initializer,
lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
rngs: rnglib.Rngs,
**kwargs,
):
super().__init__(in_features, out_features, rngs=rngs, **kwargs)
self.lora = LoRA(
in_features,
lora_rank,
out_features,
dtype=lora_dtype,
param_dtype=lora_param_dtype,
kernel_init=lora_kernel_init,
lora_param_type=lora_param_type,
rngs=rngs,
in_features,
lora_rank,
out_features,
dtype=lora_dtype,
param_dtype=lora_param_dtype,
a_initializer=a_initializer,
b_initializer=b_initializer,
lora_param_type=lora_param_type,
rngs=rngs,
)

def __call__(self, x: jax.Array):
Expand Down
4 changes: 2 additions & 2 deletions tests/nnx/nn/lora_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __call__(self, x):

assert y.shape == (1, 3)
assert lora_y.shape == (1, 3)
assert not jnp.allclose(y, lora_y)
np.testing.assert_allclose(y, lora_y)
a, b = model.linear2.lora_a.value, model.linear2.lora_b.value
np.testing.assert_allclose(y + model.linear1(x) @ a @ b, lora_y)

Expand Down Expand Up @@ -98,7 +98,7 @@ def __call__(self, x):

assert y.shape == (1, 3)
assert lora_y.shape == (1, 3)
assert not jnp.allclose(y, lora_y)
np.testing.assert_allclose(y, lora_y)
a, b = model.linear2.lora.lora_a.value, model.linear2.lora.lora_b.value
np.testing.assert_allclose(y + model.linear1(x) @ a @ b, lora_y)

Expand Down

0 comments on commit 5d4fdd2

Please sign in to comment.