Skip to content

Commit

Permalink
io casting
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed Oct 16, 2017
1 parent f43df4d commit f07a501
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
4 changes: 3 additions & 1 deletion ants/core/ants_image_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
import json
import numpy as np
import warnings

from . import ants_image as iio
from .. import utils
Expand Down Expand Up @@ -401,7 +402,8 @@ def image_read(filename, dimension=None, pixeltype='float'):
ndim = dimension

if ptype in _unsupported_ptypes:
raise ValueError('unsupported pixeltype %s' % ptype)
#warnings.warn('Casting image from unsupported type \'%s\' to closest supported type \'%s\'' % (ptype, _unsupported_ptype_map[ptype]))
ptype = _unsupported_ptype_map.get(ptype, 'unsupported')

libfn = utils.get_lib_fn(_image_read_dict[pclass][ptype][ndim])
itk_pointer = libfn(filename)
Expand Down
10 changes: 8 additions & 2 deletions ants/viz/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def plot(image, overlay=None, cmap='Greys_r', overlay_cmap='jet', overlay_alpha=


def plot_directory(directory, recursive=False, regex='*',
save_prefix='', save_suffix='', **kwargs):
save_prefix='', save_suffix='', axis=None, **kwargs):
"""
Create and save an ANTsPy plot for every image matching a given regular
expression in a directory, optionally recursively. This is a good function
Expand Down Expand Up @@ -254,8 +254,14 @@ def has_acceptable_suffix(fname):
fname = '%s%s' % (save_prefix, fname)
save_fname = os.path.join(root, fname)
img = iio2.image_read(load_fname)

if axis is None:
axis_range = [i for i in range(img.dimension)]
else:
axis_range = axis if isinstance(axis,(list,tuple)) else [axis]

if img.dimension > 2:
for axis_idx in range(img.dimension):
for axis_idx in axis_range:
filename = save_fname.replace('.png', '_axis%i.png' % axis_idx)
ncol = int(math.sqrt(img.shape[axis_idx]))
plot(img, axis=axis_idx, nslices=img.shape[axis_idx], ncol=ncol,
Expand Down
2 changes: 1 addition & 1 deletion tests/timings_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_itk():

def test_ants():
for img_path in img_paths:
array = ants.image_read(img_path).numpy()
array = ants.image_read(img_path, pixeltype='float').numpy()

nib_start = time.time()
for i in range(N_TRIALS):
Expand Down

0 comments on commit f07a501

Please sign in to comment.