diff --git a/implementations/began/began.py b/implementations/began/began.py
index d9d420d8..d6243217 100644
--- a/implementations/began/began.py
+++ b/implementations/began/began.py
@@ -2,6 +2,7 @@
 import os
 import numpy as np
 import math
+import time
 
 import torchvision.transforms as transforms
 from torchvision.utils import save_image
@@ -27,12 +28,17 @@
 parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
 parser.add_argument("--channels", type=int, default=1, help="number of image channels")
 parser.add_argument("--sample_interval", type=int, default=400, help="number of image channels")
+parser.add_argument('--inference', action='store_true', default=False)
+parser.add_argument('--precision', default='float32', help='Precision, "float32" or "bfloat16"')
+parser.add_argument('--channels_last', type=int, default=1, help='use channels last format')
+parser.add_argument('--num-iterations', default=100, type=int)
 opt = parser.parse_args()
 print(opt)
 
 img_shape = (opt.channels, opt.img_size, opt.img_size)
 
 cuda = True if torch.cuda.is_available() else False
+Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
 
 
 def weights_init_normal(m):
@@ -68,6 +74,8 @@ def __init__(self):
     def forward(self, noise):
         out = self.l1(noise)
         out = out.view(out.shape[0], 128, self.init_size, self.init_size)
+        if opt.channels_last:
+            out = out.to(memory_format=torch.channels_last)
         img = self.conv_blocks(out)
         return img
 
@@ -94,116 +102,167 @@ def __init__(self):
     def forward(self, img):
         out = self.down(img)
+        if opt.channels_last:
+            out = out.contiguous()
         out = self.fc(out.view(out.size(0), -1))
-        out = self.up(out.view(out.size(0), 64, self.down_size, self.down_size))
+        out = out.view(out.size(0), 64, self.down_size, self.down_size)
+        if opt.channels_last:
+            out = out.to(memory_format=torch.channels_last)
+        out = self.up(out)
         return out
 
 
+def main():
+    # Initialize generator and discriminator
+    generator = Generator()
+    discriminator = Discriminator()
+
+    if cuda:
+        generator.cuda()
+        discriminator.cuda()
+    else:
+        generator.cpu()
+        discriminator.cpu()
+
+    # Initialize weights
+    generator.apply(weights_init_normal)
+    discriminator.apply(weights_init_normal)
+    device = torch.device('cuda') if cuda else torch.device('cpu')
+    if opt.inference:
+        print("----------------Generation---------------")
+        if opt.precision == "bfloat16":
+            # CUDA autocast defaults to float16, so request bfloat16 explicitly
+            cm = torch.cuda.amp.autocast if cuda else torch.cpu.amp.autocast
+            with cm(dtype=torch.bfloat16):
+                generate(generator, device=device)
+        else:
+            generate(generator, device=device)
+    else:
+        print("-------------------Train-----------------")
+        train(generator, discriminator)
+
+
+def generate(netG, device):
+    fixed_noise = Variable(Tensor(np.random.normal(0, 1, (10 ** 2, opt.latent_dim))))
+    # move the noise to the target device on both the NHWC and the default path
+    fixed_noise = fixed_noise.to(device=device)
+    if opt.channels_last:
+        netG_oob = netG
+        try:
+            netG_oob = netG_oob.to(memory_format=torch.channels_last)
+            print("[INFO] Use NHWC model")
+        except Exception:
+            print("[WARN] Model NHWC failed! Use normal model")
+        netG = netG_oob
+    netG.eval()
+
+    total_iters = opt.num_iterations
+    with torch.no_grad():
+        tic = time.time()
+        for i in range(total_iters):
+            fake = netG(fixed_noise)
+        toc = time.time() - tic
+    # report against the batch actually generated (10 ** 2), not opt.batch_size
+    print("Throughput: %.2f image/sec, batchsize: %d, latency = %.2f ms" % (total_iters * fixed_noise.size(0) / toc, fixed_noise.size(0), 1000 * toc / total_iters))
Use normal model") + netG = netG_oob + else: + fixed_noise = fixed_noise.to(device=device) + netG.eval() + + total_iters = opt.num_iterations + with torch.no_grad(): + tic = time.time() + for i in range(total_iters): + fake = netG(fixed_noise) + toc = time.time() - tic + print("Throughput: %.2f image/sec, batchsize: %d, latency = %.2f ms"%((opt.num_iterations*opt.batch_size)/toc, opt.batch_size, 1000*toc/opt.num_iterations)) -# Initialize generator and discriminator -generator = Generator() -discriminator = Discriminator() - -if cuda: - generator.cuda() - discriminator.cuda() - -# Initialize weights -generator.apply(weights_init_normal) -discriminator.apply(weights_init_normal) - -# Configure data loader -os.makedirs("../../data/mnist", exist_ok=True) -dataloader = torch.utils.data.DataLoader( - datasets.MNIST( - "../../data/mnist", - train=True, - download=True, - transform=transforms.Compose( - [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] - ), - ), - batch_size=opt.batch_size, - shuffle=True, -) - -# Optimizers -optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) -optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) - -Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor # ---------- # Training # ---------- -# BEGAN hyper parameters -gamma = 0.75 -lambda_k = 0.001 -k = 0.0 - -for epoch in range(opt.n_epochs): - for i, (imgs, _) in enumerate(dataloader): - - # Configure input - real_imgs = Variable(imgs.type(Tensor)) - - # ----------------- - # Train Generator - # ----------------- - - optimizer_G.zero_grad() - - # Sample noise as generator input - z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) - - # Generate a batch of images - gen_imgs = generator(z) - - # Loss measures generator's ability to fool the discriminator - g_loss = torch.mean(torch.abs(discriminator(gen_imgs) - gen_imgs)) - - g_loss.backward() - optimizer_G.step() - - # --------------------- - # Train Discriminator - # --------------------- - - optimizer_D.zero_grad() - - # Measure discriminator's ability to classify real from generated samples - d_real = discriminator(real_imgs) - d_fake = discriminator(gen_imgs.detach()) - - d_loss_real = torch.mean(torch.abs(d_real - real_imgs)) - d_loss_fake = torch.mean(torch.abs(d_fake - gen_imgs.detach())) - d_loss = d_loss_real - k * d_loss_fake - - d_loss.backward() - optimizer_D.step() - - # ---------------- - # Update weights - # ---------------- - - diff = torch.mean(gamma * d_loss_real - d_loss_fake) - - # Update weight term for fake samples - k = k + lambda_k * diff.item() - k = min(max(k, 0), 1) # Constraint to interval [0, 1] - - # Update convergence metric - M = (d_loss_real + torch.abs(diff)).data[0] - - # -------------- - # Log Progress - # -------------- - - print( - "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] -- M: %f, k: %f" - % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item(), M, k) - ) - - batches_done = epoch * len(dataloader) + i - if batches_done % opt.sample_interval == 0: - save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True) +def train(netG, netD): + # BEGAN hyper parameters + gamma = 0.75 + lambda_k = 0.001 + k = 0.0 + + # Configure data loader + os.makedirs("../../data/mnist", exist_ok=True) + dataloader = torch.utils.data.DataLoader( + datasets.MNIST( + "../../data/mnist", + train=True, + download=True, + 
+            transform=transforms.Compose(
+                [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
+            ),
+        ),
+        batch_size=opt.batch_size,
+        shuffle=True,
+    )
+    # Optimizers
+    optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
+    optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
+
+    for epoch in range(opt.n_epochs):
+        for i, (imgs, _) in enumerate(dataloader):
+            if opt.channels_last:
+                imgs_oob = imgs
+                try:
+                    imgs_oob = imgs_oob.to(memory_format=torch.channels_last)
+                    print("[INFO] Use NHWC input")
+                except Exception:
+                    print("[WARN] Input NHWC failed! Use normal input")
+                imgs = imgs_oob
+            # Configure input
+            real_imgs = Variable(imgs.type(Tensor))
+
+            # -----------------
+            # Train Generator
+            # -----------------
+
+            optimizer_G.zero_grad()
+
+            # Sample noise as generator input
+            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
+
+            # Generate a batch of images
+            gen_imgs = netG(z)
+
+            # Loss measures generator's ability to fool the discriminator
+            g_loss = torch.mean(torch.abs(netD(gen_imgs) - gen_imgs))
+
+            g_loss.backward()
+            optimizer_G.step()
+
+            # ---------------------
+            # Train Discriminator
+            # ---------------------
+
+            optimizer_D.zero_grad()
+
+            # Measure discriminator's ability to classify real from generated samples
+            d_real = netD(real_imgs)
+            d_fake = netD(gen_imgs.detach())
+
+            d_loss_real = torch.mean(torch.abs(d_real - real_imgs))
+            d_loss_fake = torch.mean(torch.abs(d_fake - gen_imgs.detach()))
+            d_loss = d_loss_real - k * d_loss_fake
+
+            d_loss.backward()
+            optimizer_D.step()
+
+            # ----------------
+            # Update weights
+            # ----------------
+
+            diff = torch.mean(gamma * d_loss_real - d_loss_fake)
+
+            # Update weight term for fake samples
+            k = k + lambda_k * diff.item()
+            k = min(max(k, 0), 1)  # Constraint to interval [0, 1]
+
+            # Update convergence metric
+            M = (d_loss_real + torch.abs(diff)).item()
+
+            # --------------
+            # Log Progress
+            # --------------
+
+            print(
+                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] -- M: %f, k: %f"
+                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item(), M, k)
+            )
+
+            batches_done = epoch * len(dataloader) + i
+            if batches_done % opt.sample_interval == 0:
+                save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
+
+
+if __name__ == '__main__':
+    main()
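Note on the channels_last handling above (the sgan.py changes below follow the same pattern): torch.channels_last is defined only for 4D NCHW tensors, which is why the Generator re-tags the activation right after the view to (N, 128, H, W), and why the Discriminator calls .contiguous() before flattening with view(). A minimal, self-contained sketch of the pattern; the one-layer net and the random input are hypothetical stand-ins for the models in this patch:

    import torch
    import torch.nn as nn

    net = nn.Conv2d(1, 8, 3, padding=1).to(memory_format=torch.channels_last)  # conv weights reordered to NHWC
    x = torch.randn(2, 1, 32, 32).to(memory_format=torch.channels_last)        # 4D input carries the layout
    y = net(x)                          # convolution output stays channels_last
    y = y.contiguous()                  # restore the default NCHW layout
    print(y.view(y.size(0), -1).shape)  # view() is only safe after .contiguous()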
diff --git a/implementations/sgan/sgan.py b/implementations/sgan/sgan.py
index d12bf292..d7d77e1f 100644
--- a/implementations/sgan/sgan.py
+++ b/implementations/sgan/sgan.py
@@ -2,6 +2,7 @@
 import os
 import numpy as np
 import math
+import time
 
 import torchvision.transforms as transforms
 from torchvision.utils import save_image
@@ -28,10 +29,16 @@
 parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
 parser.add_argument("--channels", type=int, default=1, help="number of image channels")
 parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
+parser.add_argument('--inference', action='store_true', default=False)
+parser.add_argument('--precision', default='float32', help='Precision, "float32" or "bfloat16"')
+parser.add_argument('--channels_last', type=int, default=1, help='use channels last format')
+parser.add_argument('--num-iterations', default=100, type=int)
 opt = parser.parse_args()
 print(opt)
 
 cuda = True if torch.cuda.is_available() else False
+FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
+LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
 
 
 def weights_init_normal(m):
@@ -69,6 +76,8 @@ def __init__(self):
     def forward(self, noise):
         out = self.l1(noise)
         out = out.view(out.shape[0], 128, self.init_size, self.init_size)
+        if opt.channels_last:
+            out = out.to(memory_format=torch.channels_last)
         img = self.conv_blocks(out)
         return img
 
@@ -100,6 +109,8 @@ def discriminator_block(in_filters, out_filters, bn=True):
     def forward(self, img):
         out = self.conv_blocks(img)
+        if opt.channels_last:
+            out = out.contiguous()
         out = out.view(out.shape[0], -1)
         validity = self.adv_layer(out)
         label = self.aux_layer(out)
@@ -111,109 +122,161 @@ def forward(self, img):
 adversarial_loss = torch.nn.BCELoss()
 auxiliary_loss = torch.nn.CrossEntropyLoss()
 
-# Initialize generator and discriminator
-generator = Generator()
-discriminator = Discriminator()
-
-if cuda:
-    generator.cuda()
-    discriminator.cuda()
-    adversarial_loss.cuda()
-    auxiliary_loss.cuda()
-
-# Initialize weights
-generator.apply(weights_init_normal)
-discriminator.apply(weights_init_normal)
-
-# Configure data loader
-os.makedirs("../../data/mnist", exist_ok=True)
-dataloader = torch.utils.data.DataLoader(
-    datasets.MNIST(
-        "../../data/mnist",
-        train=True,
-        download=True,
-        transform=transforms.Compose(
-            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
-        ),
-    ),
-    batch_size=opt.batch_size,
-    shuffle=True,
-)
-
-# Optimizers
-optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
-optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
+def main():
+    # Initialize generator and discriminator
+    generator = Generator()
+    discriminator = Discriminator()
+
+    if cuda:
+        generator.cuda()
+        discriminator.cuda()
+        adversarial_loss.cuda()
+        auxiliary_loss.cuda()
+    else:
+        generator.cpu()
+        discriminator.cpu()
+        adversarial_loss.cpu()
+        auxiliary_loss.cpu()
+
+    # Initialize weights
+    generator.apply(weights_init_normal)
+    discriminator.apply(weights_init_normal)
+    device = torch.device('cuda') if cuda else torch.device('cpu')
+    if opt.inference:
+        print("----------------Generation---------------")
+        if opt.precision == "bfloat16":
+            # CUDA autocast defaults to float16, so request bfloat16 explicitly
+            cm = torch.cuda.amp.autocast if cuda else torch.cpu.amp.autocast
+            with cm(dtype=torch.bfloat16):
+                generate(generator, device=device)
+        else:
+            generate(generator, device=device)
+    else:
+        print("-------------------Train-----------------")
+        train(generator, discriminator)
+
+
+def generate(netG, device):
+    fixed_noise = Variable(FloatTensor(np.random.normal(0, 1, (10 ** 2, opt.latent_dim))))
+    # move the noise to the target device on both the NHWC and the default path
+    fixed_noise = fixed_noise.to(device=device)
+    if opt.channels_last:
+        netG_oob = netG
+        try:
+            netG_oob = netG_oob.to(memory_format=torch.channels_last)
+            print("[INFO] Use NHWC model")
+        except Exception:
+            print("[WARN] Model NHWC failed! Use normal model")
+        netG = netG_oob
+    netG.eval()
+
+    total_iters = opt.num_iterations
+    with torch.no_grad():
+        # start the clock once, outside the loop, so toc covers all iterations
+        tic = time.time()
+        for i in range(total_iters):
+            fake = netG(fixed_noise)
+        toc = time.time() - tic
+    # report against the batch actually generated (10 ** 2), not opt.batch_size
+    print("Throughput: %.2f image/sec, batchsize: %d, latency = %.2f ms" % (total_iters * fixed_noise.size(0) / toc, fixed_noise.size(0), 1000 * toc / total_iters))
Use normal model") + netG = netG_oob + else: + fixed_noise = fixed_noise.to(device=device) + netG.eval() + + total_iters = opt.num_iterations + with torch.no_grad(): + for i in range(total_iters): + tic = time.time() + fake = netG(fixed_noise) + toc = time.time() - tic + print("Throughput: %.2f image/sec, batchsize: %d, latency = %.2f ms"%((opt.num_iterations*opt.batch_size)/toc, opt.batch_size, 1000*toc/opt.num_iterations)) -FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor -LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor # ---------- # Training # ---------- -for epoch in range(opt.n_epochs): - for i, (imgs, labels) in enumerate(dataloader): +def train(netG, netD): + + # Configure data loader + os.makedirs("../../data/mnist", exist_ok=True) + dataloader = torch.utils.data.DataLoader( + datasets.MNIST( + "../../data/mnist", + train=True, + download=True, + transform=transforms.Compose( + [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] + ), + ), + batch_size=opt.batch_size, + shuffle=True, + ) + # Optimizers + optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) + optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) - batch_size = imgs.shape[0] - # Adversarial ground truths - valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False) - fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False) - fake_aux_gt = Variable(LongTensor(batch_size).fill_(opt.num_classes), requires_grad=False) + for epoch in range(opt.n_epochs): + for i, (imgs, labels) in enumerate(dataloader): + if opt.channels_last: + imgs_oob = imgs + try: + imgs_oob = imgs_oob.to(memory_format=torch.channels_last) + print("[INFO] Use NHWC input") + except: + print("[WARN] Input NHWC failed! 
Use normal input") + imgs = imgs_oob - # Configure input - real_imgs = Variable(imgs.type(FloatTensor)) - labels = Variable(labels.type(LongTensor)) + batch_size = imgs.shape[0] - # ----------------- - # Train Generator - # ----------------- + # Adversarial ground truths + valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False) + fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False) + fake_aux_gt = Variable(LongTensor(batch_size).fill_(opt.num_classes), requires_grad=False) - optimizer_G.zero_grad() + # Configure input + real_imgs = Variable(imgs.type(FloatTensor)) + labels = Variable(labels.type(LongTensor)) - # Sample noise and labels as generator input - z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim)))) + # ----------------- + # Train Generator + # ----------------- - # Generate a batch of images - gen_imgs = generator(z) + optimizer_G.zero_grad() - # Loss measures generator's ability to fool the discriminator - validity, _ = discriminator(gen_imgs) - g_loss = adversarial_loss(validity, valid) + # Sample noise and labels as generator input + z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim)))) - g_loss.backward() - optimizer_G.step() + # Generate a batch of images + gen_imgs = netG(z) - # --------------------- - # Train Discriminator - # --------------------- + # Loss measures generator's ability to fool the discriminator + validity, _ = netD(gen_imgs) + g_loss = adversarial_loss(validity, valid) - optimizer_D.zero_grad() + g_loss.backward() + optimizer_G.step() - # Loss for real images - real_pred, real_aux = discriminator(real_imgs) - d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2 + # --------------------- + # Train Discriminator + # --------------------- - # Loss for fake images - fake_pred, fake_aux = discriminator(gen_imgs.detach()) - d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, fake_aux_gt)) / 2 + optimizer_D.zero_grad() - # Total discriminator loss - d_loss = (d_real_loss + d_fake_loss) / 2 + # Loss for real images + real_pred, real_aux = netD(real_imgs) + d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2 - # Calculate discriminator accuracy - pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0) - gt = np.concatenate([labels.data.cpu().numpy(), fake_aux_gt.data.cpu().numpy()], axis=0) - d_acc = np.mean(np.argmax(pred, axis=1) == gt) + # Loss for fake images + fake_pred, fake_aux = netD(gen_imgs.detach()) + d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, fake_aux_gt)) / 2 - d_loss.backward() - optimizer_D.step() + # Total discriminator loss + d_loss = (d_real_loss + d_fake_loss) / 2 - print( - "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]" - % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item()) - ) + # Calculate discriminator accuracy + pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0) + gt = np.concatenate([labels.data.cpu().numpy(), fake_aux_gt.data.cpu().numpy()], axis=0) + d_acc = np.mean(np.argmax(pred, axis=1) == gt) + + d_loss.backward() + optimizer_D.step() + + print( + "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]" + % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item()) + ) + + batches_done = epoch * len(dataloader) + i + if batches_done % 
+                save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
 
-        batches_done = epoch * len(dataloader) + i
-        if batches_done % opt.sample_interval == 0:
-            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
+
+
+if __name__ == '__main__':
+    main()
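Note: the new flags would be exercised along these lines (hypothetical invocations, using only the options defined above; argparse exposes --num-iterations as opt.num_iterations):

    python implementations/began/began.py --inference --precision bfloat16 --num-iterations 100
    python implementations/sgan/sgan.py --inference --channels_last 0

On the precision switch: torch.cpu.amp.autocast defaults to bfloat16, while torch.cuda.amp.autocast defaults to float16, which is why both main() functions pass dtype=torch.bfloat16 explicitly. A minimal CPU-only sketch of the mechanism; the Linear layer is a hypothetical stand-in for the generators:

    import torch

    lin = torch.nn.Linear(8, 4)
    x = torch.randn(2, 8)
    with torch.cpu.amp.autocast(dtype=torch.bfloat16):
        y = lin(x)    # autocast-eligible op computed in bfloat16
    print(y.dtype)    # torch.bfloat16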