Hey all! I'm trying to implement a Wasserstein GAN (WGAN) with gradient penalty, and for this penalty I need the gradient of the output with respect to the input. In this specific case, the critic in the WGAN receives an image as input and assigns a score to it. I know how to do this in PyTorch, and it would work along these lines given a single input image:
I tried using jax.jacrev, thinking this would create the Jacobian and I could simply do jac_result @ jnp.ones((1,)), but this operation (jax.jacrev/jax.jacfwd) crashes Colab. Any help would be very much appreciated! If I should instead post this to the JAX discussion board, please do tell.
Replies: 1 comment 1 reply
See for example #579. It should look something like this:
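A minimal sketch of the idea, assuming the critic maps one image to a scalar score: because the output is a scalar, jax.grad already gives the gradient with respect to the image, and jax.vmap extends it to a batch without building a full Jacobian (which for a batched function can be far larger than needed). The names critic, params, score_grad, batched_score_grad and gradient_penalty are hypothetical placeholders, and the linear critic only stands in for a real network. In a real WGAN-GP the penalty would typically be evaluated at points interpolated between real and generated images.

```python
import jax
import jax.numpy as jnp


def critic(params, image):
    # Hypothetical stand-in critic: a linear score over the flattened image.
    w, b = params
    return jnp.dot(image.ravel(), w) + b  # scalar score for one image


# Gradient of the scalar score with respect to the image (argument index 1).
score_grad = jax.grad(critic, argnums=1)

# Map over the batch axis of the images only; params are shared across the batch.
batched_score_grad = jax.vmap(score_grad, in_axes=(None, 0))


def gradient_penalty(params, images):
    grads = batched_score_grad(params, images)  # same shape as images
    flat = grads.reshape(grads.shape[0], -1)
    # Small epsilon keeps the square root differentiable at zero.
    norms = jnp.sqrt(jnp.sum(flat ** 2, axis=1) + 1e-12)
    return jnp.mean((norms - 1.0) ** 2)


# Assumed toy shapes: a batch of 4 single-channel 28x28 images.
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = (jax.random.normal(key_w, (28 * 28,)), jnp.float32(0.0))
images = jax.random.normal(key_x, (4, 28, 28))

grads = batched_score_grad(params, images)
penalty = gradient_penalty(params, images)
```

Since gradient_penalty is an ordinary JAX function of params, it can be added to the critic loss and differentiated again with jax.grad when training.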