Commit 881685c

Merge pull request #4504 from IvyZX:no-none-var
PiperOrigin-RevId: 720670116
Flax Authors committed Jan 28, 2025
2 parents 00ee6f1 + ad642ef commit 881685c
Showing 2 changed files with 10 additions and 7 deletions.
15 changes: 9 additions & 6 deletions flax/nnx/nn/linear.py
@@ -210,6 +210,7 @@ def kernel_init_wrap(rng, shape, dtype):
       kernel_init_wrap(rngs.params(), kernel_shape, self.param_dtype)
     )
 
+    self.bias: nnx.Param[jax.Array] | None
     if self.use_bias:
 
       def bias_init_wrap(rng, shape, dtype):
@@ -226,7 +227,7 @@ def bias_init_wrap(rng, shape, dtype):
         bias_init_wrap(rngs.params(), bias_shape, self.param_dtype)
       )
     else:
-      self.bias = nnx.Param(None)
+      self.bias = None
 
   def __call__(self, inputs: Array) -> Array:
     """Applies a linear transformation to the inputs along multiple dimensions.
@@ -251,7 +252,7 @@ def __call__(self, inputs: Array) -> Array:
       if ax not in axis
     )
     kernel = self.kernel.value
-    bias = self.bias.value
+    bias = self.bias.value if self.bias is not None else None
 
     batch_ind = tuple(range(n_batch_dims))
     contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims))
@@ -333,11 +334,12 @@ def __init__(
     self.kernel = nnx.Param(
       kernel_init(kernel_key, (in_features, out_features), param_dtype)
     )
+    self.bias: nnx.Param[jax.Array] | None
     if use_bias:
       bias_key = rngs.params()
       self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
     else:
-      self.bias = nnx.Param(None)
+      self.bias = None
 
     self.in_features = in_features
     self.out_features = out_features
@@ -359,7 +361,7 @@ def __call__(self, inputs: Array) -> Array:
       The transformed input.
     """
     kernel = self.kernel.value
-    bias = self.bias.value
+    bias = self.bias.value if self.bias is not None else None
 
     inputs, kernel, bias = dtypes.promote_dtype(
       (inputs, kernel, bias), dtype=self.dtype
@@ -644,12 +646,13 @@ def __init__(
     self.kernel_shape = kernel_shape
     self.kernel = nnx.Param(kernel_init(kernel_key, kernel_shape, param_dtype))
 
+    self.bias: nnx.Param[jax.Array] | None
     if use_bias:
       bias_shape = (out_features,)
       bias_key = rngs.params()
       self.bias = nnx.Param(bias_init(bias_key, bias_shape, param_dtype))
     else:
-      self.bias = nnx.Param(None)
+      self.bias = None
 
     self.in_features = in_features
     self.out_features = out_features
@@ -755,7 +758,7 @@ def maybe_broadcast(
     if self.mask is not None:
       kernel *= self.mask
 
-    bias = self.bias.value
+    bias = self.bias.value if self.bias is not None else None
 
     inputs, kernel, bias = dtypes.promote_dtype(
       (inputs, kernel, bias), dtype=self.dtype
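
For illustration (not part of the diff), a minimal sketch of the user-visible effect, assuming the public nnx.Linear constructor and nnx.Rngs: with use_bias=False the bias attribute is now plain None rather than nnx.Param(None), so callers check the attribute itself before reading .value, mirroring the updated __call__ code above.

# Sketch under the assumptions stated above; not part of the commit.
from flax import nnx
import jax.numpy as jnp

rngs = nnx.Rngs(0)
no_bias = nnx.Linear(3, 4, use_bias=False, rngs=rngs)
with_bias = nnx.Linear(3, 4, rngs=rngs)

assert no_bias.bias is None                # previously an nnx.Param wrapping None
assert with_bias.bias.value.shape == (4,)  # still an nnx.Param when use_bias=True

# Downstream code mirrors the pattern used inside __call__:
bias = no_bias.bias.value if no_bias.bias is not None else None
y = no_bias(jnp.ones((1, 3)))              # output shape (1, 4)
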
2 changes: 1 addition & 1 deletion tests/nnx/nn/lora_test.py
@@ -40,7 +40,7 @@ def test_lora_base_module(self):
     assert y.shape == (1, 4)
     assert module.base_module == linear
     assert module.base_module.kernel.value.shape == (3, 4)
-    assert module.base_module.bias.value == None
+    assert module.base_module.bias == None
     assert module.lora_a.value.shape == (3, 2)
     assert module.lora_b.value.shape == (2, 4)
     np.testing.assert_allclose(
