diff --git a/enhance.py b/enhance.py index 585d85e..8b41b7c 100755 --- a/enhance.py +++ b/enhance.py @@ -47,6 +47,7 @@ add_arg('--train-blur', default=None, type=int, help='Sigma value for gaussian blur preprocess.') add_arg('--train-noise', default=None, type=float, help='Radius for preprocessing gaussian blur.') add_arg('--train-jpeg', default=[], nargs='+', type=int, help='JPEG compression level & range in preproc.') +add_arg('--seeds', default=False, type=str, help='File pattern to load for training seeds.') add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.') add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.') add_arg('--save-every', default=10, type=int, help='Save generator after every training epoch.') @@ -128,6 +129,28 @@ def extend(lst): return itertools.chain(lst, itertools.repeat(lst[-1])) print('{} - Using the device `{}` for neural computation.{}\n'.format(ansi.CYAN, theano.config.device, ansi.ENDC)) +def confirm_pairs(list1, list2): + new_list1 = [] + new_list2 = [] + cur1 = 0 + cur2 = 0 + len1 = len(list1) + len2 = len(list2) + while(cur1 < len1 and cur2 < len2): + base1 = os.path.basename(list1[cur1]) + base2 = os.path.basename(list2[cur2]) + if base1 == base2: + new_list1.append(list1[cur1]) + new_list2.append(list2[cur2]) + cur1 = cur1 + 1 + cur2 = cur2 + 1 + elif base1 < base2: + # continue to look on list1, don't iterate list2 + cur1 = cur1 + 1 + else: + cur2 = cur2 + 1 + print("List sizes went from {}, {} to {}, {}".format(len1, len2, len(new_list1), len(new_list2))) + return new_list1, new_list2 #====================================================================================================================== # Image Processing @@ -143,11 +166,17 @@ def __init__(self): self.orig_buffer = np.zeros((args.buffer_size, 3, self.orig_shape, self.orig_shape), dtype=np.float32) self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_shape, self.seed_shape), dtype=np.float32) - self.files = glob.glob(args.train) + self.files = sorted(glob.glob(args.train)) if len(self.files) == 0: error("There were no files found to train from searching for `{}`".format(args.train), " - Try putting all your images in one folder and using `--train=data/*.jpg`") + if args.seeds: + self.seeds = sorted(glob.glob(args.seeds)) + self.files, self.seeds = confirm_pairs(self.files, self.seeds) + else: + self.seeds = False + self.available = set(range(args.buffer_size)) self.ready = set() @@ -156,11 +185,14 @@ def __init__(self): def run(self): while True: - random.shuffle(self.files) - for f in self.files: - self.add_to_buffer(f) + indices = list(range(0, len(self.files))) + random.shuffle(indices) + + for file_index in indices: + self.add_to_buffer(file_index) - def add_to_buffer(self, f): + def add_to_buffer(self, file_index): + f = self.files[file_index] filename = os.path.join(self.cwd, f) try: orig = PIL.Image.open(filename).convert('RGB') @@ -172,14 +204,37 @@ def add_to_buffer(self, f): except Exception as e: warn('Could not load `{}` as image.'.format(filename), ' - Try fixing or removing the file before next run.') - self.files.remove(f) + del self.files[file_index] + if self.seeds: + del self.seeds[file_index] return - seed = orig - if args.train_blur is not None: - seed = seed.filter(PIL.ImageFilter.GaussianBlur(radius=random.randint(0, args.train_blur*2))) - if args.zoom > 1: - seed = seed.resize((orig.size[0]//args.zoom, orig.size[1]//args.zoom), resample=PIL.Image.LANCZOS) + # determine seed + if self.seeds: + # file based seed + f = self.seeds[file_index] + filename = os.path.join(self.cwd, f) + try: + seed = PIL.Image.open(filename).convert('RGB') + if any(s < self.seed_shape for s in seed.size): + raise ValueError('Image is too small for seed size (found {}, expected {})'.format(seed.size, self.seed_shape)) + except Exception as e: + warn('Could not load `{}` as seed image.'.format(filename), + ' - Try fixing or removing the file before next run. ({})'.format(e)) + del self.files[file_index] + del self.seeds[file_index] + return + + else: + # synthetic seed + seed = orig + # optionally blur before scaling + if args.train_blur is not None: + seed = seed.filter(PIL.ImageFilter.GaussianBlur(radius=random.randint(0, args.train_blur*2))) + # seed is scaled down version of original + if args.zoom > 1: + seed = seed.resize((orig.size[0]//args.zoom, orig.size[1]//args.zoom), resample=PIL.Image.LANCZOS) + if len(args.train_jpeg) > 0: buffer, rng = io.BytesIO(), args.train_jpeg[-1] if len(args.train_jpeg) > 1 else 15 seed.save(buffer, format='jpeg', quality=args.train_jpeg+random.randrange(-rng, +rng))