-
Notifications
You must be signed in to change notification settings - Fork 2
/
test_back_to_z_es_cma_bce.py
74 lines (72 loc) · 3.37 KB
/
test_back_to_z_es_cma_bce.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
# Build the TensorFlow session by hand so that the `allow_growth` GPU option
# can be switched on before any model is created.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config)
import keras
from keras import backend as K
# Hand the session configured above to the Keras backend.
K.set_session(session)
from keras.models import *
from models import up_bilinear
from cma import CMAEvolutionStrategy as ES
from skimage.io import imsave, imread
from skimage.transform import resize
from pixel_shuffler import PixelShuffler
import argparse
# Command-line interface: two positional arguments (`input`, `output`)
# followed by optional settings, all registered from the tables below.
parser = argparse.ArgumentParser(description='Image Generation with GAN')

for positional in ('input', 'output'):
    parser.add_argument(positional, metavar=positional, type=str, help='')

# (flag, value type, default, help text) -- kept in the original CLI order.
optional_specs = (
    ('--decoder', str, './decoder.h5', 'decoder'),
    ('--encoder', str, './encoder.h5', 'encoder'),
    ('--std', float, 0.1, ''),
    ('--sigma', float, 1.0, ''),
    ('--iterations', int, 500, ''),
    ('--populations', int, 500, ''),
    ('--offsprings', int, 200, ''),
    ('--runs', int, 10, ''),
)
for flag, value_type, fallback, description in optional_specs:
    parser.add_argument(flag, type=value_type, default=fallback, required=False, help=description)

args = parser.parse_args()
# Names the saved models need in order to be deserialised.
custom_layers = {'tf': tf, 'PixelShuffler': PixelShuffler, 'up_bilinear': up_bilinear}

decoder = load_model(args.decoder, custom_objects=custom_layers)
# The encoder is optional: it is only loaded when its file is present.
if os.path.exists(args.encoder):
    encoder = load_model(args.encoder, custom_objects=dict(custom_layers))
else:
    encoder = None

# Resize the input to the decoder's output height/width, then apply the
# affine map (v - 127.5) / 127.5, which sends 0 to -1 and 255 to 1.
target_size = decoder.output_shape[-3:-1]
resized = resize(imread(args.input), target_size, preserve_range=True)
img = (resized.astype(np.float32) - 127.5) / 127.5
# A 2-D (single-channel) image gets an explicit trailing channel axis.
if img.ndim == 2:
    img = img[..., np.newaxis]
def Fitness(img, decoder):
    """Build a batched binary-cross-entropy objective against a target image.

    Args:
        img: Target image array. Its values are mapped with ``v * 0.5 + 0.5``
            and clipped to ``[1e-8, 1 - 1e-8]`` before being used as targets.
        decoder: Object exposing ``predict(x)``; its output is mapped and
            clipped the same way and must broadcast against
            ``img[np.newaxis, ...]``.

    Returns:
        A function ``bce(x)`` taking a batch ``x`` (first axis = batch, must
        have a ``.shape``) and returning a 1-D float64 array with the mean
        BCE of each batch row. Lower values mean a closer match.
    """
    eps = 1e-8
    # Compute in float64: in float32, ``1 - 1e-8`` rounds to exactly 1.0, so
    # the upper clip would do nothing and a saturated pixel would produce
    # log(0), i.e. an infinite or NaN score.
    # The target does not depend on the candidate, so it is prepared once
    # here instead of on every evaluation.
    y_true = np.clip(
        np.asarray(img, dtype=np.float64)[np.newaxis, ...] * .5 + .5,
        eps, 1 - eps,
    )

    def bce(x):
        y_pred = np.clip(
            np.asarray(decoder.predict(x), dtype=np.float64) * .5 + .5,
            eps, 1 - eps,
        )
        log_likelihood = y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)
        # One score per candidate: average over everything but the batch axis.
        return -np.mean(log_likelihood.reshape(x.shape[0], -1), axis=-1)

    return bce
# When an encoder is available, its embedding of the target image is used
# as the starting point of every run; otherwise each run starts from noise.
if encoder is not None:
    x_mean = encoder.predict(img[np.newaxis, ...])
fitness_func = Fitness(img, decoder)

best_img = None
best_z = None
# BCE is a loss (the evolution strategy minimises it), so the best run is
# the one with the SMALLEST value: start from +inf and keep improvements.
best_score = np.inf

# Per-run outputs are written next to the final output as <name>_<run><ext>.
filename, ext = os.path.splitext(args.output)

# Clipped target for the per-run evaluation. float64 is used because in
# float32 ``1 - 1e-8`` rounds to 1.0, which would let log(0) through.
y_t = np.clip(np.asarray(img, dtype=np.float64) * .5 + .5, 1e-8, 1 - 1e-8)

for i in range(args.runs):
    print('Runs: %d / %d'%(i+1, args.runs))
    if encoder is None:
        init = np.random.randn(decoder.input_shape[-1]) * args.std
    else:
        init = x_mean[0]

    # Ask/tell loop: sample candidates, score them, update the distribution.
    es = ES(init, args.sigma)
    for _ in range(args.iterations):
        dnas = np.asarray(es.ask())
        es.tell(dnas, fitness_func(dnas))
        es.disp()
    es.result_pretty()

    # Decode the solution reported by the strategy and score it.
    z = np.asarray(es.result[0])
    img_reconstruct = decoder.predict(z[np.newaxis, ...])[0]
    y_p = np.clip(np.asarray(img_reconstruct, dtype=np.float64) * .5 + .5, 1e-8, 1 - 1e-8)
    bce = -np.mean(y_t * np.log(y_p) + (1 - y_t) * np.log(1 - y_p))
    print('bce: {:.2f}'.format(bce))

    if bce < best_score:
        best_score = bce
        best_z = z
        best_img = img_reconstruct

    # Save target and this run's reconstruction side by side.
    output_img = np.round(np.concatenate((np.squeeze(img), np.squeeze(img_reconstruct)), axis=1) * 127.5 + 127.5).astype(np.uint8)
    imsave(filename+'_%d'%i+ext, output_img)

# Without at least one scored run there is nothing to write (e.g. --runs 0).
if best_img is None:
    raise RuntimeError('no reconstruction was produced; check --runs')

# Final output: target next to the lowest-BCE reconstruction over all runs.
output_img = np.round(np.concatenate((np.squeeze(img), np.squeeze(best_img)), axis=1) * 127.5 + 127.5).astype(np.uint8)
imsave(args.output, output_img)