Questions on Using nnx.value_and_grad for Loss Calculation and Model Decoupling in Flax NNX #4476
Hi @Tomato-toast.
Modules can be passed as captures as long as no updates/mutations are performed inside. If any Module or Variable is updated inside a transform when it was passed as a capture, you will get an error, to avoid tracer leakage (and for overall correctness).
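As a minimal sketch of this rule (the model names and shapes below are illustrative assumptions, not code from this issue): a second Module is closed over by the loss function and only read, while gradients are taken with respect to the Module passed as an argument.

```python
import jax.numpy as jnp
from flax import nnx

# `target` is only read inside the loss, so it may be a capture;
# `model` is the argument nnx.value_and_grad differentiates (argnum 0).
model = nnx.Linear(2, 2, rngs=nnx.Rngs(0))
target = nnx.Linear(2, 2, rngs=nnx.Rngs(1))

def loss_fn(model, x):
    # `target` enters via closure. This is fine as long as it is never
    # updated/mutated inside the transform.
    return jnp.mean((model(x) - target(x)) ** 2)

x = jnp.ones((4, 2))
loss, grads = nnx.value_and_grad(loss_fn)(model, x)  # grads w.r.t. `model` only
```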
I'm not sure I fully understand the question here. Do you mean you want to cache the forward value and pass it as an input to the losses, or that you want to return the value of the forward pass from the losses? BTW: In the real code, …
Hello @cgarciae, thank you so much for your response, but I still have some questions.

Question 1: When using …, it throws the following error: … However, if I write it as …, it works fine, even though ….

Question 2: I want to know if caching intermediate values from the forward pass and using them as inputs to the loss function would interfere with the gradient computation (…). BTW, in my real code: … During the computation of ….

I’d greatly appreciate your clarification. Thank you again for your help!
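(The snippets above did not survive the page capture. A hypothetical illustration of the two patterns Question 1 appears to contrast — every name and shape below is invented for illustration — might look like this:)

```python
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 1, rngs=nnx.Rngs(0))  # toy stand-in for the real model
x = jnp.ones((4, 2))

def loss_wrt_array(x):
    # `model` enters as a capture; the differentiated argument is a plain array.
    return jnp.mean(model(x) ** 2)

def loss_wrt_module(model, x):
    return jnp.mean(model(x) ** 2)

# Pattern reported to raise an error: differentiating w.r.t. a non-Module argument.
# loss, dx = nnx.value_and_grad(loss_wrt_array)(x)

# Pattern reported to work: differentiating w.r.t. the Module passed as argnum 0.
loss, grads = nnx.value_and_grad(loss_wrt_module)(model, x)
```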
Hi @Tomato-toast.
Question 1: This might be a bug; you should be able to get a gradient for non-Modules. I'll look into it. Can you post a minimal repro here?
Question 2: Yeah, if you pass an intermediate from one loss function to avoid computation in a second loss function, you will get a different gradient (e.g. lots of zeros).
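A minimal sketch of why the gradients differ (the tiny model and data here are assumptions for illustration, not from the issue):

```python
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(1, 1, rngs=nnx.Rngs(0))
x = jnp.ones((4, 1))

# Pattern A: the forward pass runs inside the differentiated function,
# so gradients flow from the loss back into the model's parameters.
def loss_inside(model, x):
    return jnp.mean(model(x) ** 2)

# Pattern B: the forward pass runs outside. `cached` is a concrete array
# with no trace connection to the parameters, so the parameters receive
# zero gradient.
cached = model(x)

def loss_outside(model, cached):
    return jnp.mean(cached ** 2)  # `model` is unused here

_, grads_a = nnx.value_and_grad(loss_inside)(model, x)        # nonzero grads
_, grads_b = nnx.value_and_grad(loss_outside)(model, cached)  # all-zero grads
```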
Hi @cgarciae, thank you very much for your answer, it helped me a lot.
Hello everyone.
While implementing the A2C reinforcement learning algorithm using Flax NNX, I encountered some challenges and would appreciate your guidance. Below is a simplified version of the code (see the sketch after the background notes).

Background

- `policy_network` and `value_network` are complex models based on Transformer modules, and their forward pass involves multi-layer computational logic. These networks are implemented as `nnx.Module`.
- To compute `advantage`, which depends on the forward pass of both `policy_network` and `value_network`, I attempted to extract the forward pass as a separate step. I then passed only the results of the forward pass to `compute_policy_loss` and `compute_critic_loss`.
- `nnx.value_and_grad` seems to require passing the models directly to the loss function. As a workaround, I passed `policy_network` and `value_network` into `compute_policy_loss` and `compute_critic_loss`, and the code worked. However, the two `nnx.Module` instances are not explicitly used within these functions.
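The original snippet was lost in the page capture; the following is a hypothetical reconstruction of the pattern described above, with toy stand-ins (small linear layers, made-up shapes) for the Transformer-based networks:

```python
import jax
import jax.numpy as jnp
from flax import nnx

class PolicyNetwork(nnx.Module):
    def __init__(self, obs_dim, n_actions, *, rngs):
        self.linear = nnx.Linear(obs_dim, n_actions, rngs=rngs)

    def __call__(self, obs):
        return jax.nn.log_softmax(self.linear(obs))

class ValueNetwork(nnx.Module):
    def __init__(self, obs_dim, *, rngs):
        self.linear = nnx.Linear(obs_dim, 1, rngs=rngs)

    def __call__(self, obs):
        return self.linear(obs).squeeze(-1)

def compute_policy_loss(policy_network, obs, actions, advantage):
    # The forward pass runs *inside* the differentiated function so that
    # gradients flow into the policy parameters.
    log_probs = policy_network(obs)
    chosen = jnp.take_along_axis(log_probs, actions[:, None], axis=1).squeeze(-1)
    # stop_gradient: the advantage is a constant for the policy update.
    return -jnp.mean(chosen * jax.lax.stop_gradient(advantage))

def compute_critic_loss(value_network, obs, returns):
    values = value_network(obs)
    return jnp.mean((returns - values) ** 2)

rngs = nnx.Rngs(0)
policy_network = PolicyNetwork(4, 2, rngs=rngs)
value_network = ValueNetwork(4, rngs=rngs)

obs = jnp.ones((8, 4))
actions = jnp.zeros((8,), dtype=jnp.int32)
returns = jnp.ones((8,))

# The advantage can be computed outside the losses, but each loss must
# re-run the forward pass of the network it is supposed to train.
advantage = returns - value_network(obs)

policy_loss, policy_grads = nnx.value_and_grad(compute_policy_loss)(
    policy_network, obs, actions, advantage)
critic_loss, critic_grads = nnx.value_and_grad(compute_critic_loss)(
    value_network, obs, returns)
```

Note that each loss re-runs the forward pass of the network it trains; per the discussion above, passing cached forward results instead would zero out the gradients.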
Questions

… is this the correct way to use `nnx.value_and_grad`?