How to set up the same initializer as PyTorch for Conv layers? #4466
-
I am using Flax NNX and would like the initial weights of nnx.Conv to match those of torch.nn.Conv2d. I understand that the discrepancies might stem from differences in the initializer implementations between JAX and PyTorch; that said, I'm wondering if there's a way to make the initial weights identical. Minimal reproduction:
import torch
from jax import random
from flax import nnx
import jax.numpy as jnp
import numpy as np
t_conv = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding='valid', bias=False)
# torch.nn.init.kaiming_normal_(t_conv.weight)
# torch.nn.init.kaiming_uniform_(t_conv.weight)
kernel = t_conv.weight.detach().cpu().numpy()
# [outC, inC, kH, kW] -> [kH, kW, inC, outC]
kernel = jnp.transpose(kernel, (2, 3, 1, 0))
key = random.key(0)
x = random.normal(key, (1, 6, 6, 3))
j_conv = nnx.Conv(
    in_features=3,
    out_features=4,
    kernel_size=(2, 2),
    padding='valid',
    use_bias=False,
    rngs=nnx.Rngs(0),
    # kernel_init=nnx.initializers.kaiming_normal(),
    # kernel_init=nnx.initializers.kaiming_uniform(),
)
# Uncomment this line to port initial weights from pytorch.
# j_conv.kernel = nnx.Param(kernel)
j_out = j_conv(x)
# [N, H, W, C] -> [N, C, H, W]
t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
t_out = t_conv(t_x)
# [N, C, H, W] -> [N, H, W, C]
t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))
print("j_out", j_out)
print("t_out", t_out)
# Fails unless the PyTorch weights were ported above: even with the same
# initializer family, JAX and PyTorch draw different random values.
np.testing.assert_almost_equal(j_out, t_out, decimal=6)
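For reference, PyTorch's default Conv2d weight initializer is kaiming_uniform_ with a=sqrt(5), which works out to U(-1/sqrt(fan_in), 1/sqrt(fan_in)) with fan_in = in_channels * kH * kW. If matching the distribution (rather than the exact draws) is enough, I believe the closest JAX equivalent is variance_scaling with scale 1/3 — a sketch, where torch_like_init is just a name I picked:

# PyTorch default: kaiming_uniform_(w, a=sqrt(5)) == U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
# variance_scaling's uniform mode samples U(-sqrt(3*scale/fan_in), sqrt(3*scale/fan_in)),
# so scale = 1/3 reproduces the same bound (but not the same random values).
torch_like_init = nnx.initializers.variance_scaling(1 / 3, 'fan_in', 'uniform')
j_conv_dist = nnx.Conv(
    in_features=3,
    out_features=4,
    kernel_size=(2, 2),
    padding='valid',
    use_bias=False,
    kernel_init=torch_like_init,
    rngs=nnx.Rngs(0),
)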
-
Hi @ZedongPeng, matching the exact initialization values is a bit out of scope, since JAX and PyTorch use different RNG mechanisms. That said, you can copy over the values from an equivalent PyTorch implementation like this:
pt_state_dict = pytorch_model.state_dict()
nnx_variables = nnx.variables(nnx_model)
# copy over values
for path, variable in nnx_variables.flat_state():
    variable.value = ...  # condition on `path` to grab the correct value from `pt_state_dict`
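For the Conv example in the question, here is a minimal sketch of that idea specialized to a single layer (assuming the usual layout difference: torch.nn.Conv2d stores weights as [outC, inC, kH, kW], while nnx.Conv expects [kH, kW, inC, outC]):

import numpy as np
import jax.numpy as jnp
import torch
from flax import nnx

pt_conv = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, bias=False)
jx_conv = nnx.Conv(
    in_features=3,
    out_features=4,
    kernel_size=(2, 2),
    padding='valid',
    use_bias=False,
    rngs=nnx.Rngs(0),
)

# [outC, inC, kH, kW] -> [kH, kW, inC, outC]
w = pt_conv.state_dict()['weight'].detach().cpu().numpy()
jx_conv.kernel.value = jnp.asarray(np.transpose(w, (2, 3, 1, 0)))

With the kernel copied over, the assertion in the original snippet should pass up to float32 precision.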