Apply Gradient Not Behaving Appropriately #1030
Unanswered
vasilavramov asked this question in Q&A
Replies: 1 comment 1 reply
-
Hi @vasilavramov, I think you'll have to ask your question more concretely and/or share code to get more help. My first guess is that you're doing something like …
-
Hi, I am currently trying to train a neural network through reinforcement learning, and I am struggling to update the network's parameters once an epoch of training is complete. The gradients are computed as follows:
```python
def compute_gradients_EV(optimizer, input, params):
    def EV_short(params):
        return policy_EV(input, params)

    policy_fit, grad = jax.value_and_grad(EV_short, has_aux=False)(params)
    return policy_fit, grad
```

Applying the resulting gradients then raises:

```
TypeError: unsupported operand type(s) for *: 'float' and 'FrozenDict'
```
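For comparison, here is a minimal sketch of the update pattern the legacy `flax.optim` API expects, assuming `policy_EV(input, params)` returns a scalar loss and that the optimizer was created with something like `optim.Adam(...).create(params)`. Differentiating with respect to `optimizer.target` keeps the gradient pytree in lockstep with the parameters, and note that the legacy method is spelled `apply_gradient` (singular):

```python
import jax
from flax import optim  # legacy flax.optim API (deprecated in later Flax)

def train_step(optimizer, batch_input):
    # Differentiate w.r.t. optimizer.target so the gradient pytree has
    # exactly the same structure as the parameters being updated.
    def loss_fn(params):
        return policy_EV(batch_input, params)  # assumed to return a scalar

    loss, grad = jax.value_and_grad(loss_fn, has_aux=False)(optimizer.target)
    # Legacy flax.optim spells the update method apply_gradient (singular).
    optimizer = optimizer.apply_gradient(grad)
    return optimizer, loss
```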
The error makes no sense to me, given that the documentation for the optim class says apply_gradients should work with a pytree of gradients. I have also tried a different approach, similar to the one shown in the examples, where jax.value_and_grad is called with the model and optimizer.target is used to differentiate. In that case, however, I get an error saying that the model I pass to the differentiated function 'is not a valid Jax type'.
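The 'is not a valid Jax type' error is typically what appears when a Module object itself, rather than a pytree of arrays, is handed to a JAX transformation. A hypothetical sketch, assuming a `flax.linen` model (the `Policy` class, `model`, `batch_input`, and `params` here are stand-ins, not the asker's actual code):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Policy(nn.Module):  # hypothetical stand-in for the asker's network
    @nn.compact
    def __call__(self, x):
        # Scalar output so value_and_grad has a well-defined loss.
        return jnp.mean(nn.Dense(1)(x))

model = Policy()
batch_input = jnp.ones((4, 3))
params = model.init(jax.random.PRNGKey(0), batch_input)['params']

# Fails: the Module instance itself 'is not a valid Jax type'.
# jax.value_and_grad(lambda m: m(batch_input))(model)

# Works: close over the module and differentiate the params FrozenDict.
loss, grad = jax.value_and_grad(
    lambda p: model.apply({'params': p}, batch_input)
)(params)
```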
From what I understand, the gradients themselves are computed successfully; it is applying them through optimizer.apply_gradients() that fails, even though that should work. Any help would be much appreciated.
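One way to debug this, assuming the `optimizer` and `grad` from the snippet above: a 'float' * 'FrozenDict' failure usually means the gradient pytree does not mirror `optimizer.target`, so the learning rate ends up multiplying an entire FrozenDict instead of an array leaf. Printing both tree structures makes any mismatch visible:

```python
import jax

# If these two structures differ (e.g. grad carries an extra FrozenDict
# nesting level), the optimizer's update rule multiplies the learning
# rate into a FrozenDict leaf and raises exactly this TypeError.
print(jax.tree_util.tree_structure(optimizer.target))
print(jax.tree_util.tree_structure(grad))
```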