
How to set up the same initializer as PyTorch for Conv layers? #4466

Answered by cgarciae
ZedongPeng asked this question in Q&A

Hi @ZedongPeng, matching the exact initialization values is a bit out of scope, since JAX and PyTorch use different RNG mechanisms. That said, you can copy the values over from an equivalent PyTorch implementation like this:

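from flax import nnx

# `pytorch_model` and `nnx_model` are the two equivalent, already-constructed models
pt_state_dict = pytorch_model.state_dict()
nnx_variables = nnx.variables(nnx_model)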

# copy over values
for path, variable in nnx_variables.flat_state():
  variable.value = ... # condition on `path` to grab the correct value from `pt_state_dict`
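
One detail that tends to matter when filling in that loop: PyTorch stores a `Conv2d` weight as `(out_channels, in_channels, kH, kW)`, while Flax stores a conv kernel as `(kH, kW, in_channels, out_channels)` and calls it `kernel` rather than `weight`. The following is only a rough sketch of that conversion, not part of the original answer. The helper names `torch_conv2d_weight_to_flax` and `convert_conv_state_dict` are hypothetical, and it assumes the NNX model's attribute names mirror the PyTorch module names, that every 4-D `weight` entry is a 2-D conv kernel, and that the tensors have already been converted to NumPy arrays (for example with `.detach().cpu().numpy()`).

```python
import numpy as np


def torch_conv2d_weight_to_flax(weight):
    """Reorder a PyTorch Conv2d weight (out, in, kH, kW) to Flax's (kH, kW, in, out)."""
    weight = np.asarray(weight)
    if weight.ndim != 4:
        raise ValueError(f"expected a 4-D Conv2d weight, got shape {weight.shape}")
    return np.transpose(weight, (2, 3, 1, 0))


def convert_conv_state_dict(pt_arrays):
    """Map PyTorch-style names to tuple paths (hypothetical naming assumption).

    'conv1.weight' -> ('conv1', 'kernel'), with the layout transposed
    'conv1.bias'   -> ('conv1', 'bias'), copied unchanged
    """
    converted = {}
    for name, array in pt_arrays.items():
        *prefix, leaf = name.split(".")
        array = np.asarray(array)
        if leaf == "weight" and array.ndim == 4:
            # Assumption: a 4-D weight is a 2-D conv kernel.
            converted[(*prefix, "kernel")] = torch_conv2d_weight_to_flax(array)
        else:
            converted[(*prefix, leaf)] = array
    return converted
```

With a dictionary like this, the body of the loop above could look up the entry for `path` and assign it to `variable.value`, provided the paths really do line up. Paths for modules held in lists or with different attribute names would need their own handling.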

Answer selected by ZedongPeng