Skip to content
This repository has been archived by the owner on Jan 2, 2021. It is now read-only.

Fixes for scipy misc #258

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions enhance.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
#!/usr/bin/env python3
""" _ _
_ __ ___ _ _ _ __ __ _| | ___ _ __ | |__ __ _ _ __ ___ ___
| '_ \ / _ \ | | | '__/ _` | | / _ \ '_ \| '_ \ / _` | '_ \ / __/ _ \
| | | | __/ |_| | | | (_| | | | __/ | | | | | | (_| | | | | (_| __/
|_| |_|\___|\__,_|_| \__,_|_| \___|_| |_|_| |_|\__,_|_| |_|\___\___|

""" _ _ .
_ __ ___ _ _ _ __ __ _| | ___ _ __ | |__ __ _ _ __ ___ ___ .
| '_ \ / _ \ | | | '__/ _` | | / _ \ '_ \| '_ \ / _` | '_ \ / __/ _ \ .
| | | | __/ |_| | | | (_| | | | __/ | | | | | | (_| | | | | (_| __/ .
|_| |_|\___|\__,_|_| \__,_|_| \___|_| |_|_| |_|\__,_|_| |_|\___\___| .
"""
#
# Copyright (c) 2016, Alex J. Champandard.
Expand Down Expand Up @@ -111,7 +110,9 @@ def extend(lst): return itertools.chain(lst, itertools.repeat(lst[-1]))

# Scientific & Imaging Libraries
import numpy as np
import scipy.ndimage, scipy.misc, PIL.Image
import scipy.ndimage, PIL.Image

from scipy_misc import toimage, fromimage, imread # replacementment for scipy.misc

# Numeric Computing (GPU)
import theano, theano.tensor as T
Expand Down Expand Up @@ -185,8 +186,8 @@ def add_to_buffer(self, f):
seed.save(buffer, format='jpeg', quality=args.train_jpeg[0]+random.randrange(-rng, +rng))
seed = PIL.Image.open(buffer)

orig = scipy.misc.fromimage(orig).astype(np.float32)
seed = scipy.misc.fromimage(seed).astype(np.float32)
orig = fromimage(orig).astype(np.float32)
seed = fromimage(seed).astype(np.float32)

if args.train_noise is not None:
seed += scipy.random.normal(scale=args.train_noise, size=(seed.shape[0], seed.shape[1], 1))
Expand Down Expand Up @@ -377,7 +378,7 @@ def cast(p): return p.get_value().astype(np.float16)
params = {k: [cast(p) for p in l.get_params()] for (k, l) in self.list_generator_layers()}
config = {k: getattr(args, k) for k in ['generator_blocks', 'generator_residual', 'generator_filters'] + \
['generator_upscale', 'generator_downscale']}

pickle.dump((config, params), bz2.open(self.get_filename(absolute=True), 'wb'))
print(' - Saved model as `{}` after training.'.format(self.get_filename()))

Expand Down Expand Up @@ -467,7 +468,7 @@ def __init__(self, loader):
print('{}'.format(ansi.ENDC))

def imsave(self, fn, img):
scipy.misc.toimage(np.transpose(img + 0.5, (1, 2, 0)).clip(0.0, 1.0) * 255.0, cmin=0, cmax=255).save(fn)
toimage(np.transpose(img + 0.5, (1, 2, 0)).clip(0.0, 1.0) * 255.0, cmin=0, cmax=255).save(fn)

def show_progress(self, orign, scald, repro):
os.makedirs('valid', exist_ok=True)
Expand Down Expand Up @@ -568,7 +569,7 @@ def process(self, original):
for i in range(3):
output[:,:,i] = self.match_histograms(output[:,:,i], original[:,:,i])

return scipy.misc.toimage(output, cmin=0, cmax=255)
return toimage(output, cmin=0, cmax=255)


if __name__ == "__main__":
Expand All @@ -580,7 +581,7 @@ def process(self, original):
enhancer = NeuralEnhancer(loader=False)
for filename in args.files:
print(filename, end=' ')
img = scipy.ndimage.imread(filename, mode='RGB')
img = imread(filename)
out = enhancer.process(img)
out.save(os.path.splitext(filename)[0]+'_ne%ix.png' % args.zoom)
print(flush=True)
Expand Down
239 changes: 239 additions & 0 deletions scipy_misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
# based on https://stackoverflow.com/a/57545205 and extended
import numpy as np
from PIL import Image

_errstr = "Mode is unknown or incompatible with input array shape."

def imread(filename, flatten=False, mode=None):
return fromimage(Image.open(filename), flatten=flatten, mode=mode)

def fromimage(im, flatten=False, mode=None):
"""
Return a copy of a PIL image as a numpy array.
Parameters
----------
im : PIL image
Input image.
flatten : bool
If true, convert the output to grey-scale.
mode : str, optional
Mode to convert image to, e.g. ``'RGB'``. See the Notes of the
`imread` docstring for more details.
Returns
-------
fromimage : ndarray
The different colour bands/channels are stored in the
third dimension, such that a grey-image is MxN, an
RGB-image MxNx3 and an RGBA-image MxNx4.
"""
if not Image.isImageType(im):
raise TypeError("Input is not a PIL image.")

if mode is not None:
if mode != im.mode:
im = im.convert(mode)
elif im.mode == 'P':
# Mode 'P' means there is an indexed "palette". If we leave the mode
# as 'P', then when we do `a = array(im)` below, `a` will be a 2-D
# containing the indices into the palette, and not a 3-D array
# containing the RGB or RGBA values.
if 'transparency' in im.info:
im = im.convert('RGBA')
else:
im = im.convert('RGB')

if flatten:
im = im.convert('F')
elif im.mode == '1':
# Workaround for crash in PIL. When im is 1-bit, the call array(im)
# can cause a seg. fault, or generate garbage. See
# https://github.com/scipy/scipy/issues/2138 and
# https://github.com/python-pillow/Pillow/issues/350.
#
# This converts im from a 1-bit image to an 8-bit image.
im = im.convert('L')

a = np.array(im)
return a


def bytescale(data, cmin=None, cmax=None, high=255, low=0):
"""
Byte scales an array (image).
Byte scaling means converting the input image to uint8 dtype and scaling
the range to ``(low, high)`` (default 0-255).
If the input image already has dtype uint8, no scaling is done.
This function is only available if Python Imaging Library (PIL) is installed.
Parameters
----------
data : ndarray
PIL image data array.
cmin : scalar, optional
Bias scaling of small values. Default is ``data.min()``.
cmax : scalar, optional
Bias scaling of large values. Default is ``data.max()``.
high : scalar, optional
Scale max value to `high`. Default is 255.
low : scalar, optional
Scale min value to `low`. Default is 0.
Returns
-------
img_array : uint8 ndarray
The byte-scaled array.
Examples
--------
>>> from scipy.misc import bytescale
>>> img = np.array([[ 91.06794177, 3.39058326, 84.4221549 ],
... [ 73.88003259, 80.91433048, 4.88878881],
... [ 51.53875334, 34.45808177, 27.5873488 ]])
>>> bytescale(img)
array([[255, 0, 236],
[205, 225, 4],
[140, 90, 70]], dtype=uint8)
>>> bytescale(img, high=200, low=100)
array([[200, 100, 192],
[180, 188, 102],
[155, 135, 128]], dtype=uint8)
>>> bytescale(img, cmin=0, cmax=255)
array([[91, 3, 84],
[74, 81, 5],
[52, 34, 28]], dtype=uint8)
"""
if data.dtype == np.uint8:
return data

if high > 255:
raise ValueError("`high` should be less than or equal to 255.")
if low < 0:
raise ValueError("`low` should be greater than or equal to 0.")
if high < low:
raise ValueError("`high` should be greater than or equal to `low`.")

if cmin is None:
cmin = data.min()
if cmax is None:
cmax = data.max()

cscale = cmax - cmin
if cscale < 0:
raise ValueError("`cmax` should be larger than `cmin`.")
elif cscale == 0:
cscale = 1

scale = float(high - low) / cscale
bytedata = (data - cmin) * scale + low
return (bytedata.clip(low, high) + 0.5).astype(np.uint8)


def toimage(arr, high=255, low=0, cmin=None, cmax=None, pal=None,
mode=None, channel_axis=None):
"""Takes a numpy array and returns a PIL image.
This function is only available if Python Imaging Library (PIL) is installed.
The mode of the PIL image depends on the array shape and the `pal` and
`mode` keywords.
For 2-D arrays, if `pal` is a valid (N,3) byte-array giving the RGB values
(from 0 to 255) then ``mode='P'``, otherwise ``mode='L'``, unless mode
is given as 'F' or 'I' in which case a float and/or integer array is made.
.. warning::
This function uses `bytescale` under the hood to rescale images to use
the full (0, 255) range if ``mode`` is one of ``None, 'L', 'P', 'l'``.
It will also cast data for 2-D images to ``uint32`` for ``mode=None``
(which is the default).
Notes
-----
For 3-D arrays, the `channel_axis` argument tells which dimension of the
array holds the channel data.
For 3-D arrays if one of the dimensions is 3, the mode is 'RGB'
by default or 'YCbCr' if selected.
The numpy array must be either 2 dimensional or 3 dimensional.
"""
data = np.asarray(arr)
if np.iscomplexobj(data):
raise ValueError("Cannot convert a complex-valued array.")
shape = list(data.shape)
valid = len(shape) == 2 or ((len(shape) == 3) and
((3 in shape) or (4 in shape)))
if not valid:
raise ValueError("'arr' does not have a suitable array shape for "
"any mode.")
if len(shape) == 2:
shape = (shape[1], shape[0]) # columns show up first
if mode == 'F':
data32 = data.astype(np.float32)
image = Image.frombytes(mode, shape, data32.tostring())
return image
if mode in [None, 'L', 'P']:
bytedata = bytescale(data, high=high, low=low,
cmin=cmin, cmax=cmax)
image = Image.frombytes('L', shape, bytedata.tostring())
if pal is not None:
image.putpalette(np.asarray(pal, dtype=np.uint8).tostring())
# Becomes a mode='P' automagically.
elif mode == 'P': # default gray-scale
pal = (np.arange(0, 256, 1, dtype=np.uint8)[:, np.newaxis] *
np.ones((3,), dtype=np.uint8)[np.newaxis, :])
image.putpalette(np.asarray(pal, dtype=np.uint8).tostring())
return image
if mode == '1': # high input gives threshold for 1
bytedata = (data > high)
image = Image.frombytes('1', shape, bytedata.tostring())
return image
if cmin is None:
cmin = np.amin(np.ravel(data))
if cmax is None:
cmax = np.amax(np.ravel(data))
data = (data*1.0 - cmin)*(high - low)/(cmax - cmin) + low
if mode == 'I':
data32 = data.astype(np.uint32)
image = Image.frombytes(mode, shape, data32.tostring())
else:
raise ValueError(_errstr)
return image

# if here then 3-d array with a 3 or a 4 in the shape length.
# Check for 3 in datacube shape --- 'RGB' or 'YCbCr'
if channel_axis is None:
if (3 in shape):
ca = np.flatnonzero(np.asarray(shape) == 3)[0]
else:
ca = np.flatnonzero(np.asarray(shape) == 4)
if len(ca):
ca = ca[0]
else:
raise ValueError("Could not find channel dimension.")
else:
ca = channel_axis

numch = shape[ca]
if numch not in [3, 4]:
raise ValueError("Channel axis dimension is not valid.")

bytedata = bytescale(data, high=high, low=low, cmin=cmin, cmax=cmax)
if ca == 2:
strdata = bytedata.tostring()
shape = (shape[1], shape[0])
elif ca == 1:
strdata = np.transpose(bytedata, (0, 2, 1)).tostring()
shape = (shape[2], shape[0])
elif ca == 0:
strdata = np.transpose(bytedata, (1, 2, 0)).tostring()
shape = (shape[2], shape[1])
if mode is None:
if numch == 3:
mode = 'RGB'
else:
mode = 'RGBA'

if mode not in ['RGB', 'RGBA', 'YCbCr', 'CMYK']:
raise ValueError(_errstr)

if mode in ['RGB', 'YCbCr']:
if numch != 3:
raise ValueError("Invalid array shape for mode.")
if mode in ['RGBA', 'CMYK']:
if numch != 4:
raise ValueError("Invalid array shape for mode.")

# Here we know data and mode is correct
image = Image.frombytes(mode, shape, strdata)
return image