diff --git a/README.rst b/README.rst index 0bf99e7..1d75292 100644 --- a/README.rst +++ b/README.rst @@ -94,7 +94,7 @@ Pre-trained models are provided in the GitHub releases. Training your own is a The easiest way to get up-and-running is to `install Docker `_. Then, you should be able to download and run the pre-built image using the ``docker`` command line tool. Find out more about the ``alexjc/neural-enhance`` image on its `Docker Hub `_ page. -Here's the simplest way you can call the script using ``docker``, assuming you're familiar with using ``-v`` argument to mount folders you can use this directly to specify files to enhance: +Here's the simplest way you can call the script using ``docker``, assuming you're familiar with using ``-v`` argument to mount folders (see `documentation `_) you can use this directly to specify files to enhance: .. code:: bash @@ -161,7 +161,9 @@ This code uses a combination of techniques from the following papers, as well as Special thanks for their help and support in various ways: +* Roelof Pieters — Provided a rack of TitanX GPUs for training model variations on OpenImages dataset. * Eder Santana — Discussions, encouragement, and his ideas on `sub-pixel deconvolution `_. +* Wenzhe Shi — Practical advice and feedback on training procedures for the super-resolution GAN [4]. * Andrew Brock — This sub-pixel layer code is based on `his project repository `_ using Lasagne. * Casper Kaae Sønderby — For suggesting a more stable alternative to sigmoid + log as GAN loss functions. diff --git a/enhance.py b/enhance.py index 5d704d3..4e9c75e 100755 --- a/enhance.py +++ b/enhance.py @@ -14,7 +14,7 @@ # without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. # -__version__ = '0.3' +__version__ = '0.4' import io import os @@ -32,7 +32,7 @@ # Configure all options first so we can later custom-load other libraries (Theano) based on device specified by user. -parser = argparse.ArgumentParser(description='Generate a new image by applying style onto a content image.', +parser = argparse.ArgumentParser(description='Enhance a low-res image into high-def using neural networks.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) add_arg = parser.add_argument add_arg('files', nargs='*', default=[]) @@ -43,10 +43,11 @@ add_arg('--type', default='photo', type=str, help='Name of the neural network to load/save.') add_arg('--model', default='default', type=str, help='Specific trained version of the model.') add_arg('--train', default=False, type=str, help='File pattern to load for training.') -add_arg('--train-scales', default=0, type=int, help='Randomly resize images this many times.') -add_arg('--train-blur', default=None, type=int, help='Sigma value for gaussian blur preprocess.') -add_arg('--train-noise', default=None, type=float, help='Radius for preprocessing gaussian blur.') -add_arg('--train-jpeg', default=[], nargs='+', type=int, help='JPEG compression level & range in preproc.') +add_arg('--train-scales', default=[0], nargs='+', type=int, help='Randomly resize images, specify min/max.') +add_arg('--train-blur', default=[], nargs='+', type=int, help='Sigma value for gaussian blur, min/max.') +add_arg('--train-noise', default=None, type=float, help='Distribution for gaussian noise preprocess.') +add_arg('--train-jpeg', default=[], nargs='+', type=int, help='JPEG compression level, specify min/max.') +add_arg('--train-plugin', default=None, type=str, help='Filename for python pre-processing script.') add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.') add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.') add_arg('--save-every', default=10, type=int, help='Save generator after every training epoch.') @@ -139,14 +140,24 @@ def __init__(self): self.data_ready = threading.Event() self.data_copied = threading.Event() + if args.train_plugin is not None: + import importlib.util + spec = importlib.util.spec_from_file_location('enhance.plugin', 'plugins/{}.py'.format(args.train_plugin)) + plugin = importlib.util.module_from_spec(spec) + spec.loader.exec_module(plugin) + + self.iterate_files = plugin.iterate_files + self.load_original = plugin.load_original + self.load_seed = plugin.load_seed + self.orig_shape, self.seed_shape = args.batch_shape, args.batch_shape // args.zoom self.orig_buffer = np.zeros((args.buffer_size, 3, self.orig_shape, self.orig_shape), dtype=np.float32) self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_shape, self.seed_shape), dtype=np.float32) self.files = glob.glob(args.train) if len(self.files) == 0: - error("There were no files found to train from searching for `{}`".format(args.train), - " - Try putting all your images in one folder and using `--train=data/*.jpg`") + error('There were no files found to train from searching for `{}`'.format(args.train), + ' - Try putting all your images in one folder and using `--train="data/*.jpg"`') self.available = set(range(args.buffer_size)) self.ready = set() @@ -154,43 +165,58 @@ def __init__(self): self.cwd = os.getcwd() self.start() - def run(self): + def iterate_files(self): while True: random.shuffle(self.files) for f in self.files: - self.add_to_buffer(f) + yield f - def add_to_buffer(self, f): - filename = os.path.join(self.cwd, f) + def load_original(self, filename): try: orig = PIL.Image.open(filename).convert('RGB') - scale = 2 ** random.randint(0, args.train_scales) + scale = 2 ** random.randint(args.train_scales[0], args.train_scales[-1]) if scale > 1 and all(s//scale >= args.batch_shape for s in orig.size): - orig = orig.resize((orig.size[0]//scale, orig.size[1]//scale), resample=PIL.Image.LANCZOS) + orig = orig.resize((orig.size[0]//scale, orig.size[1]//scale), resample=random.randint(0,3)) if any(s < args.batch_shape for s in orig.size): raise ValueError('Image is too small for training with size {}'.format(orig.size)) + return scipy.misc.fromimage(orig).astype(np.float32) except Exception as e: warn('Could not load `{}` as image.'.format(filename), ' - Try fixing or removing the file before next run.') - self.files.remove(f) - return + self.files.remove(filename) + return None - seed = orig - if args.train_blur is not None: - seed = seed.filter(PIL.ImageFilter.GaussianBlur(radius=random.randint(0, args.train_blur*2))) + def load_seed(self, filename, original, zoom): + seed = scipy.misc.toimage(original) + if len(args.train_blur): + seed = seed.filter(PIL.ImageFilter.GaussianBlur(radius=random.randint(args.train_blur[0], args.train_blur[-1]))) if args.zoom > 1: - seed = seed.resize((orig.size[0]//args.zoom, orig.size[1]//args.zoom), resample=PIL.Image.LANCZOS) + seed = seed.resize((seed.size[0]//zoom, seed.size[1]//zoom), resample=random.randint(0,3)) + if len(args.train_jpeg) > 0: - buffer, rng = io.BytesIO(), args.train_jpeg[-1] if len(args.train_jpeg) > 1 else 15 - seed.save(buffer, format='jpeg', quality=args.train_jpeg[0]+random.randrange(-rng, +rng)) + buffer = io.BytesIO() + seed.save(buffer, format='jpeg', quality=random.randrange(args.train_jpeg[0], args.train_jpeg[-1])) seed = PIL.Image.open(buffer) - orig = scipy.misc.fromimage(orig).astype(np.float32) seed = scipy.misc.fromimage(seed).astype(np.float32) - if args.train_noise is not None: seed += scipy.random.normal(scale=args.train_noise, size=(seed.shape[0], seed.shape[1], 1)) + return seed + def run(self): + for filename in self.iterate_files(): + f = os.path.join(self.cwd, filename) + orig = self.load_original(f) + if orig is None: continue + + seed = self.load_seed(f, orig, args.zoom) + if seed is None: continue + + self.enqueue(orig, seed) + + raise ValueError('Insufficient number of files found for training.') + + def enqueue(self, orig, seed): for _ in range(seed.shape[0] * seed.shape[1] // (args.buffer_fraction * self.seed_shape ** 2)): h = random.randint(0, seed.shape[0] - self.seed_shape) w = random.randint(0, seed.shape[1] - self.seed_shape) @@ -241,7 +267,34 @@ def up(d): return self.upscale * d if d else d def get_output_for(self, input, deterministic=False, **kwargs): out, r = T.zeros(self.get_output_shape_for(input.shape)), self.upscale for y, x in itertools.product(range(r), repeat=2): - out=T.inc_subtensor(out[:,:,y::r,x::r], input[:,r*y+x::r*r,:,:]) + out = T.set_subtensor(out[:,:,y::r,x::r], input[:,r*y+x::r*r,:,:]) + return out + + +class ReflectLayer(lasagne.layers.Layer): + """Based on more code by ajbrock: https://gist.github.com/ajbrock/a3858c26282d9731191901b397b3ce9f + """ + + def __init__(self, incoming, pad, batch_ndim=2, **kwargs): + super(ReflectLayer, self).__init__(incoming, **kwargs) + self.pad = pad + self.batch_ndim = batch_ndim + + def get_output_shape_for(self, input_shape): + output_shape = list(input_shape) + for k, p in enumerate(self.pad): + if output_shape[k + self.batch_ndim] is None: continue + output_shape[k + self.batch_ndim] += p * 2 + return tuple(output_shape) + + def get_output_for(self, x, **kwargs): + out = T.zeros(self.get_output_shape_for(x.shape)) + p0, p1 = self.pad + out = T.set_subtensor(out[:,:,:p0,p1:-p1], x[:,:,p0:0:-1,:]) + out = T.set_subtensor(out[:,:,-p0:,p1:-p1], x[:,:,-2:-(2+p0):-1,:]) + out = T.set_subtensor(out[:,:,p0:-p0,p1:-p1], x) + out = T.set_subtensor(out[:,:,:,:p1], out[:,:,:,(2*p1):p1:-1]) + out = T.set_subtensor(out[:,:,:,-p1:], out[:,:,:,-(p1+2):-(2*p1+2):-1]) return out @@ -270,17 +323,30 @@ def __init__(self): def last_layer(self): return list(self.network.values())[-1] - def make_layer(self, name, input, units, filter_size=(3,3), stride=(1,1), pad=(1,1), alpha=0.25): - conv = ConvLayer(input, units, filter_size, stride=stride, pad=pad, nonlinearity=None) - prelu = lasagne.layers.ParametricRectifierLayer(conv, alpha=lasagne.init.Constant(alpha)) + def make_layer(self, name, input, units, filter_size=(3,3), stride=(1,1), pad=(1,1), alpha=0.25, reuse=False): + clone = '0/'+name.split('/')[-1] + if reuse and clone+'x' in self.network: + extra = {'W': self.network[clone+'x'].W, 'b': self.network[clone+'x'].b} + else: + extra = {} + + padded = ReflectLayer(input, pad) if pad[0] > 0 and pad[1] > 0 else input + conv = ConvLayer(padded, units, filter_size, stride=stride, pad=0, nonlinearity=None, **extra) self.network[name+'x'] = conv - self.network[name+'>'] = prelu - return prelu + + if reuse and clone+'>' in self.network: + extra = {'alpha': self.network[clone+'>'].alpha} + else: + extra = {} + self.network[name+'>'] = lasagne.layers.ParametricRectifierLayer(conv, **extra) + return self.last_layer() def make_block(self, name, input, units): - self.make_layer(name+'-A', input, units, alpha=0.1) - # self.make_layer(name+'-B', self.last_layer(), units, alpha=1.0) - return ElemwiseSumLayer([input, self.last_layer()]) if args.generator_residual else self.last_layer() + self.make_layer(name+'-A', input, units, alpha=0.25) + self.make_layer(name+'-B', self.last_layer(), units, alpha=1.0) + if args.generator_residual: + self.network[name+'-R'] = ElemwiseSumLayer([input, self.last_layer()]) + return self.last_layer() def setup_generator(self, input, config): for k, v in config.items(): setattr(args, k, v) @@ -288,21 +354,22 @@ def setup_generator(self, input, config): units_iter = extend(args.generator_filters) units = next(units_iter) - self.make_layer('iter.0', input, units, filter_size=(7,7), pad=(3,3)) + self.make_layer('encode', input, units, filter_size=(7,7), pad=(3,3)) for i in range(0, args.generator_downscale): - self.make_layer('downscale%i'%i, self.last_layer(), next(units_iter), filter_size=(4,4), stride=(2,2)) + self.make_layer('%i/downscale'%i, self.last_layer(), next(units_iter), filter_size=4, stride=2, reuse=True) units = next(units_iter) for i in range(0, args.generator_blocks): - self.make_block('iter.%i'%(i+1), self.last_layer(), units) + self.make_block('default.%i'%i, self.last_layer(), units) for i in range(0, args.generator_upscale): u = next(units_iter) - self.make_layer('upscale%i.2'%i, self.last_layer(), u*4) - self.network['upscale%i.1'%i] = SubpixelReshuffleLayer(self.last_layer(), u, 2) + self.make_layer('%i/upscale.2'%i, self.last_layer(), u*4, reuse=True) + self.network['%i/upscale.1'%i] = SubpixelReshuffleLayer(self.last_layer(), u, 2) - self.network['out'] = ConvLayer(self.last_layer(), 3, filter_size=(7,7), pad=(3,3), nonlinearity=None) + self.network['decode'] = ConvLayer(self.last_layer(), 3, filter_size=(7,7), pad=(3,3), nonlinearity=None) + self.network['out'] = self.last_layer() def setup_perceptual(self, input): """Use lasagne to create a network of convolution layers using pre-trained VGG19 weights. @@ -477,12 +544,10 @@ def show_progress(self, orign, scald, repro): self.imsave('valid/%s_%03i_reprod.png' % (args.model, i), repro[i]) def decay_learning_rate(self): - l_r, t_cur = args.learning_rate, 0 - - while True: + l_r = args.learning_rate + for t_cur in itertools.count(): yield l_r - t_cur += 1 - if t_cur % args.learning_period == 0: l_r *= args.learning_decay + if (t_cur+1) % args.learning_period == 0: l_r *= args.learning_decay def train(self): seed_size = args.batch_shape // args.zoom diff --git a/plugins/simple.py b/plugins/simple.py new file mode 100644 index 0000000..5e759e3 --- /dev/null +++ b/plugins/simple.py @@ -0,0 +1,16 @@ +import glob +import itertools + +import scipy.misc +import scipy.ndimage + + +def iterate_files(): + return itertools.cycle(glob.glob('data/*.jpg')) + +def load_original(filename): + return scipy.ndimage.imread(filename, mode='RGB') + +def load_seed(filename, original, zoom): + target_shape = (original.shape[0]//zoom, original.shape[1]//zoom) + return scipy.misc.imresize(original, target_shape, interp='bilinear')