diff --git a/1. Vanilla GAN PyTorch.ipynb b/1. Vanilla GAN PyTorch.ipynb index e1fe594..3182096 100644 --- a/1. Vanilla GAN PyTorch.ipynb +++ b/1. Vanilla GAN PyTorch.ipynb @@ -218,7 +218,14 @@ "g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)\n", "\n", "# Loss function\n", - "loss = nn.BCELoss()\n", + "class StableBCELoss(nn.modules.Module):", + " def __init__(self):", + " super(StableBCELoss, self).__init__()", + " def forward(self, input, target):", + " neg_abs = - input.abs()", + " loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()", + " return loss.mean()", + "loss = StableBCELoss()\n", "\n", "# Number of steps to apply to the discriminator\n", "d_steps = 1 # In Goodfellow et. al 2014 this variable is assigned to 1\n",