-
Notifications
You must be signed in to change notification settings - Fork 3
/
loss.py
32 lines (25 loc) · 1.2 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import tensorflow as tf
def gan_loss(logits_real, logits_fake):
    """Compute the standard (sigmoid cross-entropy) GAN losses.

    Both inputs are raw, unnormalized discriminator outputs (logits), NOT
    probabilities or log-probabilities: the sigmoid is applied inside
    ``tf.nn.sigmoid_cross_entropy_with_logits``.

    Inputs:
    - logits_real: Tensor of discriminator logits for real images,
      conventionally shaped [batch_size, 1]. Any shape works, because the
      labels are built with ``tf.ones_like``.
    - logits_fake: Tensor of discriminator logits for generated images,
      conventionally shaped [batch_size, 1]. Any shape works, because the
      labels are built with ``tf.ones_like`` / ``tf.zeros_like``.

    Returns:
    - D_loss: discriminator loss scalar,
      mean(CE(logits_real, 1)) + mean(CE(logits_fake, 0)).
    - G_loss: generator loss scalar, mean(CE(logits_fake, 1)).
    """
    with tf.variable_scope("G_loss_gan"):
        # Non-saturating generator objective: label the fake samples as real,
        # i.e. minimize -log(sigmoid(logits_fake)) instead of maximizing
        # -log(1 - sigmoid(logits_fake)).
        labels_ones_f = tf.ones_like(logits_fake)
        G_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=labels_ones_f, logits=logits_fake))
    with tf.variable_scope("D_loss_gan"):
        labels_zeros_f = tf.zeros_like(logits_fake)
        labels_ones_r = tf.ones_like(logits_real)
        # The discriminator wants real -> 1 and fake -> 0. The two terms are
        # averaged separately and then summed, so each batch contributes
        # equally regardless of its size.
        D_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=labels_ones_r, logits=logits_real))
        D_loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=labels_zeros_f, logits=logits_fake))
    return D_loss, G_loss
def l1_loss(real, fake):
    """Return the mean absolute error between ``real`` and ``fake``.

    The difference is taken element-wise with ``real - fake``, then its
    absolute value is averaged over every element. The ops are created
    inside the "L1_loss" variable scope.
    """
    with tf.variable_scope("L1_loss"):
        abs_diff = tf.abs(real - fake)
        mean_abs_diff = tf.reduce_mean(abs_diff)
    return mean_abs_diff