Skip to content
This repository has been archived by the owner on Jan 2, 2021. It is now read-only.

Commit

Permalink
Added -seeds option to provide seeds from disk
Browse files Browse the repository at this point in the history
  • Loading branch information
dribnet committed Nov 17, 2016
1 parent 9d2aa3c commit 400b63a
Showing 1 changed file with 66 additions and 11 deletions.
77 changes: 66 additions & 11 deletions enhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
add_arg('--train-blur', default=None, type=int, help='Sigma value for gaussian blur preprocess.')
add_arg('--train-noise', default=None, type=float, help='Radius for preprocessing gaussian blur.')
add_arg('--train-jpeg', default=[], nargs='+', type=int, help='JPEG compression level & range in preproc.')
add_arg('--seeds', default=False, type=str, help='File pattern to load for training seeds.')
add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.')
add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.')
add_arg('--save-every', default=10, type=int, help='Save generator after every training epoch.')
Expand Down Expand Up @@ -128,6 +129,28 @@ def extend(lst): return itertools.chain(lst, itertools.repeat(lst[-1]))

print('{} - Using the device `{}` for neural computation.{}\n'.format(ansi.CYAN, theano.config.device, ansi.ENDC))

def confirm_pairs(list1, list2):
new_list1 = []
new_list2 = []
cur1 = 0
cur2 = 0
len1 = len(list1)
len2 = len(list2)
while(cur1 < len1 and cur2 < len2):
base1 = os.path.basename(list1[cur1])
base2 = os.path.basename(list2[cur2])
if base1 == base2:
new_list1.append(list1[cur1])
new_list2.append(list2[cur2])
cur1 = cur1 + 1
cur2 = cur2 + 1
elif base1 < base2:
# continue to look on list1, don't iterate list2
cur1 = cur1 + 1
else:
cur2 = cur2 + 1
print("List sizes went from {}, {} to {}, {}".format(len1, len2, len(new_list1), len(new_list2)))
return new_list1, new_list2

#======================================================================================================================
# Image Processing
Expand All @@ -143,11 +166,17 @@ def __init__(self):

self.orig_buffer = np.zeros((args.buffer_size, 3, self.orig_shape, self.orig_shape), dtype=np.float32)
self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_shape, self.seed_shape), dtype=np.float32)
self.files = glob.glob(args.train)
self.files = sorted(glob.glob(args.train))
if len(self.files) == 0:
error("There were no files found to train from searching for `{}`".format(args.train),
" - Try putting all your images in one folder and using `--train=data/*.jpg`")

if args.seeds:
self.seeds = sorted(glob.glob(args.seeds))
self.files, self.seeds = confirm_pairs(self.files, self.seeds)
else:
self.seeds = False

self.available = set(range(args.buffer_size))
self.ready = set()

Expand All @@ -156,11 +185,14 @@ def __init__(self):

def run(self):
while True:
random.shuffle(self.files)
for f in self.files:
self.add_to_buffer(f)
indices = list(range(0, len(self.files)))
random.shuffle(indices)

for file_index in indices:
self.add_to_buffer(file_index)

def add_to_buffer(self, f):
def add_to_buffer(self, file_index):
f = self.files[file_index]
filename = os.path.join(self.cwd, f)
try:
orig = PIL.Image.open(filename).convert('RGB')
Expand All @@ -172,14 +204,37 @@ def add_to_buffer(self, f):
except Exception as e:
warn('Could not load `{}` as image.'.format(filename),
' - Try fixing or removing the file before next run.')
self.files.remove(f)
del self.files[file_index]
if self.seeds:
del self.seeds[file_index]
return

seed = orig
if args.train_blur is not None:
seed = seed.filter(PIL.ImageFilter.GaussianBlur(radius=random.randint(0, args.train_blur*2)))
if args.zoom > 1:
seed = seed.resize((orig.size[0]//args.zoom, orig.size[1]//args.zoom), resample=PIL.Image.LANCZOS)
# determine seed
if self.seeds:
# file based seed
f = self.seeds[file_index]
filename = os.path.join(self.cwd, f)
try:
seed = PIL.Image.open(filename).convert('RGB')
if any(s < self.seed_shape for s in seed.size):
raise ValueError('Image is too small for seed size (found {}, expected {})'.format(seed.size, self.seed_shape))
except Exception as e:
warn('Could not load `{}` as seed image.'.format(filename),
' - Try fixing or removing the file before next run. ({})'.format(e))
del self.files[file_index]
del self.seeds[file_index]
return

else:
# synthetic seed
seed = orig
# optionally blur before scaling
if args.train_blur is not None:
seed = seed.filter(PIL.ImageFilter.GaussianBlur(radius=random.randint(0, args.train_blur*2)))
# seed is scaled down version of original
if args.zoom > 1:
seed = seed.resize((orig.size[0]//args.zoom, orig.size[1]//args.zoom), resample=PIL.Image.LANCZOS)

if len(args.train_jpeg) > 0:
buffer, rng = io.BytesIO(), args.train_jpeg[-1] if len(args.train_jpeg) > 1 else 15
seed.save(buffer, format='jpeg', quality=args.train_jpeg+random.randrange(-rng, +rng))
Expand Down

0 comments on commit 400b63a

Please sign in to comment.