From ad642ef6ca52e428b5d22ae7c7af2ca305d4b293 Mon Sep 17 00:00:00 2001
From: IvyZX
Date: Fri, 24 Jan 2025 18:27:38 -0800
Subject: [PATCH] Remove all Param(None) lines

---
 flax/nnx/nn/linear.py     | 15 +++++++++------
 tests/nnx/nn/lora_test.py |  2 +-
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/flax/nnx/nn/linear.py b/flax/nnx/nn/linear.py
index 71bf9313f..e6cd308bd 100644
--- a/flax/nnx/nn/linear.py
+++ b/flax/nnx/nn/linear.py
@@ -210,6 +210,7 @@ def kernel_init_wrap(rng, shape, dtype):
       kernel_init_wrap(rngs.params(), kernel_shape, self.param_dtype)
     )
 
+    self.bias: nnx.Param[jax.Array] | None
     if self.use_bias:
 
       def bias_init_wrap(rng, shape, dtype):
@@ -226,7 +227,7 @@ def bias_init_wrap(rng, shape, dtype):
         bias_init_wrap(rngs.params(), bias_shape, self.param_dtype)
       )
     else:
-      self.bias = nnx.Param(None)
+      self.bias = None
 
   def __call__(self, inputs: Array) -> Array:
     """Applies a linear transformation to the inputs along multiple dimensions.
@@ -251,7 +252,7 @@ def __call__(self, inputs: Array) -> Array:
       if ax not in axis
     )
     kernel = self.kernel.value
-    bias = self.bias.value
+    bias = self.bias.value if self.bias is not None else None
 
     batch_ind = tuple(range(n_batch_dims))
     contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims))
@@ -333,11 +334,12 @@ def __init__(
     self.kernel = nnx.Param(
       kernel_init(kernel_key, (in_features, out_features), param_dtype)
     )
+    self.bias: nnx.Param[jax.Array] | None
     if use_bias:
       bias_key = rngs.params()
       self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
     else:
-      self.bias = nnx.Param(None)
+      self.bias = None
 
     self.in_features = in_features
     self.out_features = out_features
@@ -359,7 +361,7 @@ def __call__(self, inputs: Array) -> Array:
       The transformed input.
     """
     kernel = self.kernel.value
-    bias = self.bias.value
+    bias = self.bias.value if self.bias is not None else None
 
     inputs, kernel, bias = dtypes.promote_dtype(
       (inputs, kernel, bias), dtype=self.dtype
@@ -644,12 +646,13 @@ def __init__(
     self.kernel_shape = kernel_shape
     self.kernel = nnx.Param(kernel_init(kernel_key, kernel_shape, param_dtype))
 
+    self.bias: nnx.Param[jax.Array] | None
    if use_bias:
       bias_shape = (out_features,)
       bias_key = rngs.params()
       self.bias = nnx.Param(bias_init(bias_key, bias_shape, param_dtype))
     else:
-      self.bias = nnx.Param(None)
+      self.bias = None
 
     self.in_features = in_features
     self.out_features = out_features
@@ -755,7 +758,7 @@ def maybe_broadcast(
     if self.mask is not None:
       kernel *= self.mask
 
-    bias = self.bias.value
+    bias = self.bias.value if self.bias is not None else None
 
     inputs, kernel, bias = dtypes.promote_dtype(
       (inputs, kernel, bias), dtype=self.dtype
diff --git a/tests/nnx/nn/lora_test.py b/tests/nnx/nn/lora_test.py
index 0fec6a05f..02a7bffb7 100644
--- a/tests/nnx/nn/lora_test.py
+++ b/tests/nnx/nn/lora_test.py
@@ -41,7 +41,7 @@ def test_lora_base_module(self):
     assert y.shape == (1, 4)
     assert module.base_module == linear
     assert module.base_module.kernel.value.shape == (3, 4)
-    assert module.base_module.bias.value == None
+    assert module.base_module.bias == None
     assert module.lora_a.value.shape == (3, 2)
     assert module.lora_b.value.shape == (2, 4)
     np.testing.assert_allclose(
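
Note: a minimal usage sketch of the behavioral change, assuming only the public
flax.nnx API already exercised by the patched files (nnx.Linear, nnx.Rngs,
nnx.Param); the variable names below are illustrative, not part of the patch.
After this change, a layer built with use_bias=False stores a plain None in its
bias attribute instead of an nnx.Param wrapping None, so callers check the
attribute itself rather than its .value.

    import jax.numpy as jnp
    from flax import nnx

    # Linear layer without a bias term: bias is now a plain None attribute.
    linear = nnx.Linear(3, 4, use_bias=False, rngs=nnx.Rngs(0))
    assert linear.bias is None                  # previously: linear.bias.value was None
    assert isinstance(linear.kernel, nnx.Param)

    # The forward pass is unchanged; the bias is simply skipped when it is None.
    y = linear(jnp.ones((1, 3)))
    assert y.shape == (1, 4)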