Issue with Optimizer Update in A2C Network with Optax Body: #4391
Comments
Hi @Tomato-toast, can you post some pseudocode showing how you are constructing the Optimizer and gradients?
Below is a pseudo-code example of how the Optimizer and gradients are constructed and applied:
Thanks!
I would just offer my input here and some suggestions based on my relatively short experience with NNX. I noticed you are using the
Ok, to the matter at hand, the problem here is with the
You have:
def a2c_loss(self, policy_network, params, observations, actions, returns):
When you transform with
flax.nnx.value_and_grad(f=<class 'flax.typing.Missing'>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())
You see by default
grad_fn = nnx.value_and_grad(self.a2c_loss, has_aux=True)
Thus you are taking the derivative with respect to
grad_fn = nnx.value_and_grad(self.a2c_loss, argnums=1, has_aux=True)
This will take the derivative with respect to
On a more general note, you can simplify your code significantly. For example,
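To make the effect of argnums concrete, here is a minimal sketch using jax.value_and_grad, which has the same argnums and has_aux parameters as the nnx.value_and_grad signature quoted above. The ToyAgent class, its toy_loss method, and all values are hypothetical stand-ins, not the a2c_loss from this thread. With a bound method, self is not counted, so argnums=0 refers to the first argument after self.

```python
import jax
import jax.numpy as jnp


class ToyAgent:
    """Hypothetical stand-in; not the agent class from this issue."""

    def toy_loss(self, scale, params, x):
        # Returns (loss, aux), which is the shape has_aux=True expects.
        pred = scale * params["w"] * x
        loss = jnp.mean(pred ** 2)
        return loss, {"pred": pred}


agent = ToyAgent()
scale = jnp.asarray(2.0)
params = {"w": jnp.asarray(3.0)}
x = jnp.array([1.0, 2.0])

# Default argnums=0: differentiate with respect to `scale`,
# the first argument after `self`.
(loss0, aux0), grad_wrt_scale = jax.value_and_grad(
    agent.toy_loss, has_aux=True
)(scale, params, x)

# argnums=1: differentiate with respect to `params`; the gradient
# comes back as a dict with the same keys as `params`.
(loss1, aux1), grad_wrt_params = jax.value_and_grad(
    agent.toy_loss, argnums=1, has_aux=True
)(scale, params, x)
```

The gradient always mirrors the structure of the argument selected by argnums, so picking the wrong argument produces a gradient tree that does not line up with the parameters the optimizer is tracking.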
@Tomato-toast how are the
policy_output = policy_network(params.actor, observations)
critic_output = policy_network(params.critic, observations)
They seem to be Modules that take in their
Since you are using a functional style training loop, I'd recommend storing the
Regarding the
I would like to extend my heartfelt gratitude for your previous assistance and suggestions. Following your guidance, I have attempted to modify the code from
to
However, I have encountered a new error with the following message:
This indicates that the issue may not be related to the
I have also taken your earlier advice on simplifying the code seriously and will incorporate it into my future work. Thank you once again for your valuable insights and support.
Thank you very much for your valuable suggestion! Following your guidance, I referred to the example in examples/nnx_toy_examples/03_train_state.py and attempted to use the
Here are the details of my attempt and the error message:
Example code
Error message
At this point, I’m unsure whether the issue could be related to the way graphdef is defined. Does it require specific attention?
Hello everyone,
I've encountered a problem while implementing an A2C (Advantage Actor-Critic) network involving Flax and Optax. My network includes policy_network and value_network, each containing policy_head and torso. When attempting to use optimizer.update(grad), I received the following error:
ValueError: Mismatch custom node data: ('policy_head', 'torso') != ('policy_network', 'value_network');
The error message indicates that the expected keys are ('policy_network', 'value_network'), but the actual provided keys are ('policy_head', 'torso'). The structure of my model parameters is as follows:
State({
  'policy_network': {
    'policy_head': {...},
    'torso': {...},
  },
  'value_network': {
    'policy_head': {...},
    'torso': {...},
  },
})
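As a rough illustration of this class of error, the sketch below uses plain nested dicts with JAX's pytree utilities rather than NNX State objects, so the exact message differs from the one above. The tree contents are hypothetical placeholders. It shows that a sub-tree holding only 'policy_head' and 'torso' does not have the same structure as the full tree keyed by 'policy_network' and 'value_network', and that mapping over the two together fails.

```python
import jax
import jax.numpy as jnp

# Hypothetical placeholder leaves; only the key structure matters here.
full_tree = {
    "policy_network": {"policy_head": jnp.zeros(2), "torso": jnp.zeros(3)},
    "value_network": {"policy_head": jnp.zeros(1), "torso": jnp.zeros(3)},
}
# A tree covering just one network, e.g. gradients for that network only.
sub_tree = full_tree["policy_network"]

full_def = jax.tree_util.tree_structure(full_tree)
sub_def = jax.tree_util.tree_structure(sub_tree)
structures_match = full_def == sub_def

# Mapping over two trees requires the second to match the first's structure.
try:
    jax.tree_util.tree_map(lambda a, b: a + b, full_tree, sub_tree)
    mismatch_error = None
except (ValueError, TypeError) as err:
    mismatch_error = err

print("structures match:", structures_match)
print("error raised:", type(mismatch_error).__name__)
```

Comparing the two tree structures before calling the optimizer is a cheap way to find out which side (gradients or tracked parameters) has the unexpected top-level keys.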
I have tried to combine the model parameters and pass them to the optimizer, like this:
params = {'w1': model1_params, 'w2': model2_params}
However, this approach did not resolve the issue. I'm wondering if there is another way to correctly initialize and update the A2C network's parameters using Optax in Flax.
If you have any suggestions or need more information, please let me know. Thank you very much for your help!