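"""Encode reference images into StyleGAN's dlatent space.

For each image in src_dir, a fine-tuned ResNet predicts an initial dlatent,
which is then refined by minimizing a perceptual loss against the reference
image; the resulting images and dlatents are saved to generated_images_dir
and dlatent_dir.
"""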
import os
import argparse
import pickle
from tqdm import tqdm
import PIL.Image
import numpy as np
import dnnlib.tflib as tflib
from keras.models import load_model
from keras.applications.resnet50 import preprocess_input
from encoder.generator_model import Generator
from encoder.perceptual_model import PerceptualModel, load_images


def split_to_batches(l, n):
    """Yield successive n-sized batches from the list l."""
    for i in range(0, len(l), n):
        yield l[i:i + n]
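# For example, list(split_to_batches([1, 2, 3, 4, 5], 2)) == [[1, 2], [3, 4], [5]].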


def main():
    parser = argparse.ArgumentParser(description='Find latent representations of reference images using perceptual loss')
    parser.add_argument('src_dir', help='Directory with images for encoding')
    parser.add_argument('generated_images_dir', help='Directory for storing generated images')
    parser.add_argument('dlatent_dir', help='Directory for storing dlatent representations')
    parser.add_argument('--load_resnet', default='./model/finetuned_resnetOld.h5', help='Path to the fine-tuned ResNet model')
    # For now it is unclear whether a larger batch leads to better performance/quality.
    parser.add_argument('--batch_size', default=1, help='Batch size for generator and perceptual model', type=int)
    # Perceptual model params
    parser.add_argument('--image_size', default=256, help='Size of images for perceptual model', type=int)
    parser.add_argument('--lr', default=0.02, help='Learning rate for perceptual model', type=float)
    parser.add_argument('--resnet_image_size', default=256, help='Size of images for the ResNet model', type=int)
    parser.add_argument('--decay_steps', default=300, help='Decay steps for learning rate decay (as a percent of iterations)', type=float)
    parser.add_argument('--decay_rate', default=0.9, help='Decay rate for learning rate', type=float)
    parser.add_argument('--iterations', default=1500, help='Number of optimization steps for each batch', type=int)
    # Generator params
    # argparse's type=bool treats any non-empty string (including 'False') as True,
    # so parse the flag value explicitly.
    parser.add_argument('--randomize_noise', default=False, help='Add noise to dlatents during optimization', type=lambda x: str(x).lower() in ('true', '1'))
    # Loss function options
    parser.add_argument('--use_vgg_loss', default=4, help='Use VGG perceptual loss; 0 to disable, > 0 to scale.', type=float)
    parser.add_argument('--use_pixel_loss', default=1.5, help='Use logcosh image pixel loss; 0 to disable, > 0 to scale.', type=float)
    args, other_args = parser.parse_known_args()

    ref_images = [os.path.join(args.src_dir, x) for x in os.listdir(args.src_dir)]
    ref_images = list(filter(os.path.isfile, ref_images))
    if len(ref_images) == 0:
        raise Exception('%s is empty' % args.src_dir)

    os.makedirs(args.generated_images_dir, exist_ok=True)
    os.makedirs(args.dlatent_dir, exist_ok=True)

    # Initialize generator and perceptual model
    tflib.init_tf()
    with open('./model/stylegan.pkl', 'rb') as f:
        generator_network, discriminator_network, Gs_network = pickle.load(f)
    generator = Generator(Gs_network, args.batch_size, randomize_noise=args.randomize_noise)
    perceptual_model = PerceptualModel(args, batch_size=args.batch_size)
    perceptual_model.build_perceptual_model(generator)
    resnet_model = load_model(args.load_resnet)

    # Optimize dlatents by minimizing the perceptual loss between reference and generated images
    n_batches = (len(ref_images) + args.batch_size - 1) // args.batch_size  # ceil division, so a final partial batch is counted
    for images_batch in tqdm(split_to_batches(ref_images, args.batch_size), total=n_batches):
        names = [os.path.splitext(os.path.basename(x))[0] for x in images_batch]

        perceptual_model.set_reference_images(images_batch)
        # Predict an initial dlatent for each image with the fine-tuned ResNet ...
        dlatents = resnet_model.predict(preprocess_input(load_images(images_batch, img_size=args.resnet_image_size)))
        generator.set_dlatents(dlatents)
        # ... then refine it with iterative optimization against the perceptual loss.
        op = perceptual_model.optimize(generator.dlatent_variable, iterations=args.iterations)
        pbar = tqdm(op, leave=False, total=args.iterations)
        for loss_dict in pbar:
            pbar.set_description(" ".join(names) + ": " + "; ".join(["{} {:.4f}".format(k, v) for k, v in loss_dict.items()]))
        generator.stochastic_clip_dlatents()

        # Generate images from the found dlatents and save them
        generated_images = generator.generate_images()
        generated_dlatents = generator.get_dlatents()
        for img_array, dlatent, img_name in zip(generated_images, generated_dlatents, names):
            img = PIL.Image.fromarray(img_array, 'RGB')
            img.save(os.path.join(args.generated_images_dir, f'{img_name}.png'), 'PNG')
            np.save(os.path.join(args.dlatent_dir, f'{img_name}.npy'), dlatent)

        generator.reset_dlatents()


if __name__ == "__main__":
    main()
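
# A minimal example invocation (hypothetical paths; the positional arguments are
# src_dir, generated_images_dir, and dlatent_dir as defined above):
#   python encode_images.py aligned_images/ generated_images/ latent_representations/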