diff --git a/enhance.py b/enhance.py index becb851..3251535 100755 --- a/enhance.py +++ b/enhance.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 -""" _ _ - _ __ ___ _ _ _ __ __ _| | ___ _ __ | |__ __ _ _ __ ___ ___ - | '_ \ / _ \ | | | '__/ _` | | / _ \ '_ \| '_ \ / _` | '_ \ / __/ _ \ - | | | | __/ |_| | | | (_| | | | __/ | | | | | | (_| | | | | (_| __/ - |_| |_|\___|\__,_|_| \__,_|_| \___|_| |_|_| |_|\__,_|_| |_|\___\___| - +""" _ _ . + _ __ ___ _ _ _ __ __ _| | ___ _ __ | |__ __ _ _ __ ___ ___ . + | '_ \ / _ \ | | | '__/ _` | | / _ \ '_ \| '_ \ / _` | '_ \ / __/ _ \ . + | | | | __/ |_| | | | (_| | | | __/ | | | | | | (_| | | | | (_| __/ . + |_| |_|\___|\__,_|_| \__,_|_| \___|_| |_|_| |_|\__,_|_| |_|\___\___| . """ # # Copyright (c) 2016, Alex J. Champandard. @@ -111,7 +110,9 @@ def extend(lst): return itertools.chain(lst, itertools.repeat(lst[-1])) # Scientific & Imaging Libraries import numpy as np -import scipy.ndimage, scipy.misc, PIL.Image +import scipy.ndimage, PIL.Image + +from scipy_misc import toimage, fromimage, imread # replacementment for scipy.misc # Numeric Computing (GPU) import theano, theano.tensor as T @@ -185,8 +186,8 @@ def add_to_buffer(self, f): seed.save(buffer, format='jpeg', quality=args.train_jpeg[0]+random.randrange(-rng, +rng)) seed = PIL.Image.open(buffer) - orig = scipy.misc.fromimage(orig).astype(np.float32) - seed = scipy.misc.fromimage(seed).astype(np.float32) + orig = fromimage(orig).astype(np.float32) + seed = fromimage(seed).astype(np.float32) if args.train_noise is not None: seed += scipy.random.normal(scale=args.train_noise, size=(seed.shape[0], seed.shape[1], 1)) @@ -377,7 +378,7 @@ def cast(p): return p.get_value().astype(np.float16) params = {k: [cast(p) for p in l.get_params()] for (k, l) in self.list_generator_layers()} config = {k: getattr(args, k) for k in ['generator_blocks', 'generator_residual', 'generator_filters'] + \ ['generator_upscale', 'generator_downscale']} - + pickle.dump((config, params), bz2.open(self.get_filename(absolute=True), 'wb')) print(' - Saved model as `{}` after training.'.format(self.get_filename())) @@ -467,7 +468,7 @@ def __init__(self, loader): print('{}'.format(ansi.ENDC)) def imsave(self, fn, img): - scipy.misc.toimage(np.transpose(img + 0.5, (1, 2, 0)).clip(0.0, 1.0) * 255.0, cmin=0, cmax=255).save(fn) + toimage(np.transpose(img + 0.5, (1, 2, 0)).clip(0.0, 1.0) * 255.0, cmin=0, cmax=255).save(fn) def show_progress(self, orign, scald, repro): os.makedirs('valid', exist_ok=True) @@ -568,7 +569,7 @@ def process(self, original): for i in range(3): output[:,:,i] = self.match_histograms(output[:,:,i], original[:,:,i]) - return scipy.misc.toimage(output, cmin=0, cmax=255) + return toimage(output, cmin=0, cmax=255) if __name__ == "__main__": @@ -580,7 +581,7 @@ def process(self, original): enhancer = NeuralEnhancer(loader=False) for filename in args.files: print(filename, end=' ') - img = scipy.ndimage.imread(filename, mode='RGB') + img = imread(filename) out = enhancer.process(img) out.save(os.path.splitext(filename)[0]+'_ne%ix.png' % args.zoom) print(flush=True) diff --git a/scipy_misc.py b/scipy_misc.py new file mode 100644 index 0000000..cc870e1 --- /dev/null +++ b/scipy_misc.py @@ -0,0 +1,239 @@ +# based on https://stackoverflow.com/a/57545205 and extended +import numpy as np +from PIL import Image + +_errstr = "Mode is unknown or incompatible with input array shape." + +def imread(filename, flatten=False, mode=None): + return fromimage(Image.open(filename), flatten=flatten, mode=mode) + +def fromimage(im, flatten=False, mode=None): + """ + Return a copy of a PIL image as a numpy array. + Parameters + ---------- + im : PIL image + Input image. + flatten : bool + If true, convert the output to grey-scale. + mode : str, optional + Mode to convert image to, e.g. ``'RGB'``. See the Notes of the + `imread` docstring for more details. + Returns + ------- + fromimage : ndarray + The different colour bands/channels are stored in the + third dimension, such that a grey-image is MxN, an + RGB-image MxNx3 and an RGBA-image MxNx4. + """ + if not Image.isImageType(im): + raise TypeError("Input is not a PIL image.") + + if mode is not None: + if mode != im.mode: + im = im.convert(mode) + elif im.mode == 'P': + # Mode 'P' means there is an indexed "palette". If we leave the mode + # as 'P', then when we do `a = array(im)` below, `a` will be a 2-D + # containing the indices into the palette, and not a 3-D array + # containing the RGB or RGBA values. + if 'transparency' in im.info: + im = im.convert('RGBA') + else: + im = im.convert('RGB') + + if flatten: + im = im.convert('F') + elif im.mode == '1': + # Workaround for crash in PIL. When im is 1-bit, the call array(im) + # can cause a seg. fault, or generate garbage. See + # https://github.com/scipy/scipy/issues/2138 and + # https://github.com/python-pillow/Pillow/issues/350. + # + # This converts im from a 1-bit image to an 8-bit image. + im = im.convert('L') + + a = np.array(im) + return a + + +def bytescale(data, cmin=None, cmax=None, high=255, low=0): + """ + Byte scales an array (image). + Byte scaling means converting the input image to uint8 dtype and scaling + the range to ``(low, high)`` (default 0-255). + If the input image already has dtype uint8, no scaling is done. + This function is only available if Python Imaging Library (PIL) is installed. + Parameters + ---------- + data : ndarray + PIL image data array. + cmin : scalar, optional + Bias scaling of small values. Default is ``data.min()``. + cmax : scalar, optional + Bias scaling of large values. Default is ``data.max()``. + high : scalar, optional + Scale max value to `high`. Default is 255. + low : scalar, optional + Scale min value to `low`. Default is 0. + Returns + ------- + img_array : uint8 ndarray + The byte-scaled array. + Examples + -------- + >>> from scipy.misc import bytescale + >>> img = np.array([[ 91.06794177, 3.39058326, 84.4221549 ], + ... [ 73.88003259, 80.91433048, 4.88878881], + ... [ 51.53875334, 34.45808177, 27.5873488 ]]) + >>> bytescale(img) + array([[255, 0, 236], + [205, 225, 4], + [140, 90, 70]], dtype=uint8) + >>> bytescale(img, high=200, low=100) + array([[200, 100, 192], + [180, 188, 102], + [155, 135, 128]], dtype=uint8) + >>> bytescale(img, cmin=0, cmax=255) + array([[91, 3, 84], + [74, 81, 5], + [52, 34, 28]], dtype=uint8) + """ + if data.dtype == np.uint8: + return data + + if high > 255: + raise ValueError("`high` should be less than or equal to 255.") + if low < 0: + raise ValueError("`low` should be greater than or equal to 0.") + if high < low: + raise ValueError("`high` should be greater than or equal to `low`.") + + if cmin is None: + cmin = data.min() + if cmax is None: + cmax = data.max() + + cscale = cmax - cmin + if cscale < 0: + raise ValueError("`cmax` should be larger than `cmin`.") + elif cscale == 0: + cscale = 1 + + scale = float(high - low) / cscale + bytedata = (data - cmin) * scale + low + return (bytedata.clip(low, high) + 0.5).astype(np.uint8) + + +def toimage(arr, high=255, low=0, cmin=None, cmax=None, pal=None, + mode=None, channel_axis=None): + """Takes a numpy array and returns a PIL image. + This function is only available if Python Imaging Library (PIL) is installed. + The mode of the PIL image depends on the array shape and the `pal` and + `mode` keywords. + For 2-D arrays, if `pal` is a valid (N,3) byte-array giving the RGB values + (from 0 to 255) then ``mode='P'``, otherwise ``mode='L'``, unless mode + is given as 'F' or 'I' in which case a float and/or integer array is made. + .. warning:: + This function uses `bytescale` under the hood to rescale images to use + the full (0, 255) range if ``mode`` is one of ``None, 'L', 'P', 'l'``. + It will also cast data for 2-D images to ``uint32`` for ``mode=None`` + (which is the default). + Notes + ----- + For 3-D arrays, the `channel_axis` argument tells which dimension of the + array holds the channel data. + For 3-D arrays if one of the dimensions is 3, the mode is 'RGB' + by default or 'YCbCr' if selected. + The numpy array must be either 2 dimensional or 3 dimensional. + """ + data = np.asarray(arr) + if np.iscomplexobj(data): + raise ValueError("Cannot convert a complex-valued array.") + shape = list(data.shape) + valid = len(shape) == 2 or ((len(shape) == 3) and + ((3 in shape) or (4 in shape))) + if not valid: + raise ValueError("'arr' does not have a suitable array shape for " + "any mode.") + if len(shape) == 2: + shape = (shape[1], shape[0]) # columns show up first + if mode == 'F': + data32 = data.astype(np.float32) + image = Image.frombytes(mode, shape, data32.tostring()) + return image + if mode in [None, 'L', 'P']: + bytedata = bytescale(data, high=high, low=low, + cmin=cmin, cmax=cmax) + image = Image.frombytes('L', shape, bytedata.tostring()) + if pal is not None: + image.putpalette(np.asarray(pal, dtype=np.uint8).tostring()) + # Becomes a mode='P' automagically. + elif mode == 'P': # default gray-scale + pal = (np.arange(0, 256, 1, dtype=np.uint8)[:, np.newaxis] * + np.ones((3,), dtype=np.uint8)[np.newaxis, :]) + image.putpalette(np.asarray(pal, dtype=np.uint8).tostring()) + return image + if mode == '1': # high input gives threshold for 1 + bytedata = (data > high) + image = Image.frombytes('1', shape, bytedata.tostring()) + return image + if cmin is None: + cmin = np.amin(np.ravel(data)) + if cmax is None: + cmax = np.amax(np.ravel(data)) + data = (data*1.0 - cmin)*(high - low)/(cmax - cmin) + low + if mode == 'I': + data32 = data.astype(np.uint32) + image = Image.frombytes(mode, shape, data32.tostring()) + else: + raise ValueError(_errstr) + return image + + # if here then 3-d array with a 3 or a 4 in the shape length. + # Check for 3 in datacube shape --- 'RGB' or 'YCbCr' + if channel_axis is None: + if (3 in shape): + ca = np.flatnonzero(np.asarray(shape) == 3)[0] + else: + ca = np.flatnonzero(np.asarray(shape) == 4) + if len(ca): + ca = ca[0] + else: + raise ValueError("Could not find channel dimension.") + else: + ca = channel_axis + + numch = shape[ca] + if numch not in [3, 4]: + raise ValueError("Channel axis dimension is not valid.") + + bytedata = bytescale(data, high=high, low=low, cmin=cmin, cmax=cmax) + if ca == 2: + strdata = bytedata.tostring() + shape = (shape[1], shape[0]) + elif ca == 1: + strdata = np.transpose(bytedata, (0, 2, 1)).tostring() + shape = (shape[2], shape[0]) + elif ca == 0: + strdata = np.transpose(bytedata, (1, 2, 0)).tostring() + shape = (shape[2], shape[1]) + if mode is None: + if numch == 3: + mode = 'RGB' + else: + mode = 'RGBA' + + if mode not in ['RGB', 'RGBA', 'YCbCr', 'CMYK']: + raise ValueError(_errstr) + + if mode in ['RGB', 'YCbCr']: + if numch != 3: + raise ValueError("Invalid array shape for mode.") + if mode in ['RGBA', 'CMYK']: + if numch != 4: + raise ValueError("Invalid array shape for mode.") + + # Here we know data and mode is correct + image = Image.frombytes(mode, shape, strdata) + return image \ No newline at end of file