Training discriminator/generator pair in GAN #359
Unanswered · MandaloreUltimate asked this question in General · Replies: 1 comment
-
I'm currently trying to build a DCGAN for MNIST and got stuck on training the discriminator and generator simultaneously, because I'm confused about what should go into the respective loss functions. My first code snippet results in zero gradients. I then tried moving the computation of `generated_images`, `real_output`, and `fake_output` inside the loss functions (second snippet); that technically started training, but the generator produced strange images and seemed to be collapsing.

The only example I could find to refer to was this pull request, but it uses a single loss function and, as stated in its comments, that is not going to work. It is also noted there that there should be a single optimizer for the model, but how would one optimizer handle both the discriminator and the generator at the same time? The only workaround I could come up with is to train the discriminator independently, freeze its new weights, and stack it with the generator into a new 'GAN' model, but that is not what I need.
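The original snippets are not reproduced here, but going by the reply below, the first attempt likely had this shape: everything that depends on the parameters is computed outside the loss function. A minimal sketch of that zero-gradient pattern, assuming the legacy `flax.optim` API (`optimizer.target` holds the model); the function and variable names are illustrative, not the actual snippet:

```python
import jax
import jax.numpy as jnp

# Shape of the zero-gradient attempt: images and logits are computed
# up front from optimizer.target, so the loss function below never
# touches its own argument and its gradient is identically zero.
def broken_discriminator_step(generator_opt, discriminator_opt, real_images, latents):
    generated_images = generator_opt.target(latents)
    real_output = discriminator_opt.target(real_images)       # constants with
    fake_output = discriminator_opt.target(generated_images)  # respect to the loss

    def discriminator_loss(discriminator):
        # `discriminator` is never applied to anything; the exact loss
        # formula is irrelevant, the missing dependence is the bug.
        return jnp.mean(fake_output) - jnp.mean(real_output)

    loss, grads = jax.value_and_grad(discriminator_loss)(discriminator_opt.target)
    return discriminator_opt.apply_gradient(grads)  # no-op update: grads are zero
```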
Beta Was this translation helpful? Give feedback.
0 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
-
Because the generator and discriminator have different loss functions, you need two separate calls to `jax.value_and_grad`. The second snippet is closer, but it still has a bug: inside the loss functions you are using `optimizer.target` instead of the model passed in as the argument. Gradients are taken only with respect to that argument, so this way no gradients are computed. What you need is a training step along the following lines.
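A minimal sketch of that corrected step, assuming the legacy `flax.optim` API (`optimizer.target` holds the model, `apply_gradient` returns the updated optimizer); `LATENT_DIM`, `bce_with_logits`, and the `train_step` signature are illustrative, not from the original reply:

```python
import jax
import jax.numpy as jnp

LATENT_DIM = 64  # illustrative choice, not from the thread

def bce_with_logits(logits, labels):
    # Numerically stable sigmoid binary cross-entropy.
    return jnp.mean(jnp.maximum(logits, 0) - logits * labels
                    + jnp.log1p(jnp.exp(-jnp.abs(logits))))

@jax.jit
def train_step(generator_opt, discriminator_opt, real_images, rng):
    latents = jax.random.normal(rng, (real_images.shape[0], LATENT_DIM))

    def discriminator_loss(discriminator):
        # `discriminator` is the differentiated argument; the generator is
        # read from its optimizer because we do not differentiate it here.
        generated = generator_opt.target(latents)
        real_logits = discriminator(real_images)
        fake_logits = discriminator(generated)
        return (bce_with_logits(real_logits, jnp.ones_like(real_logits))
                + bce_with_logits(fake_logits, jnp.zeros_like(fake_logits)))

    def generator_loss(generator):
        # Symmetric: differentiate only the generator and read the
        # (fixed) discriminator from its optimizer.
        fake_logits = discriminator_opt.target(generator(latents))
        return bce_with_logits(fake_logits, jnp.ones_like(fake_logits))

    # Two separate value_and_grad calls, one per loss/optimizer pair.
    d_loss, d_grads = jax.value_and_grad(discriminator_loss)(discriminator_opt.target)
    discriminator_opt = discriminator_opt.apply_gradient(d_grads)

    g_loss, g_grads = jax.value_and_grad(generator_loss)(generator_opt.target)
    generator_opt = generator_opt.apply_gradient(g_grads)

    return generator_opt, discriminator_opt, d_loss, g_loss
```

Note that each sub-model keeps its own optimizer here, which is what the two `value_and_grad` calls pair with naturally; that is one common way to sidestep the question of how a single optimizer could serve both networks.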