Skip to content

Commit

Permalink
fixed bug with disc losses printing
Browse files Browse the repository at this point in the history
  • Loading branch information
ysaatchi committed Jul 30, 2018
1 parent 1868041 commit ca06613
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions run_bgan_semi.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def b_dcgan(dataset, args):

session = get_session()
tf.set_random_seed(args.random_seed)
# due to how much the TF code sucks all functions take fixed batch_size at all times

dcgan = BDCGAN_Semi(x_dim, z_dim, dataset_size, batch_size=batch_size, J=args.J, J_d=args.J_d, M=args.M,
num_layers=args.num_layers,
lr=args.lr, optimizer=args.optimizer, gf_dim=args.gf_dim,
Expand Down Expand Up @@ -179,16 +179,13 @@ def b_dcgan(dataset, args):

### compute disc losses
batch_z = np.random.uniform(-1, 1, [batch_size, z_dim, dcgan.num_gen])
disc_info = session.run(optimizer_dict["disc_semi"] + dcgan.d_losses, # + [dcgan.d_probs] + [dcgan.d_hh],
disc_info = session.run(optimizer_dict["disc_semi"] + dcgan.d_losses,
feed_dict={dcgan.labeled_inputs: labeled_image_batch,
dcgan.labels: labels,
dcgan.inputs: image_batch,
dcgan.z: batch_z,
dcgan.d_semi_learning_rate: learning_rate})

d_losses = disc_info[num_disc:num_disc*2]

#print disc_info[num_disc*2:num_disc*3][0][:, 0]
d_losses = [d_ for d_ in disc_info if d_ is not None]

### compute generative losses
batch_z = np.random.uniform(-1, 1, [batch_size, z_dim, dcgan.num_gen])
Expand Down

0 comments on commit ca06613

Please sign in to comment.