From f07a5018818e98d3200072c1385e94f7fa70f3dd Mon Sep 17 00:00:00 2001 From: ncullen93 Date: Mon, 16 Oct 2017 14:58:26 -0400 Subject: [PATCH] io casting --- ants/core/ants_image_io.py | 4 +++- ants/viz/plot.py | 10 ++++++++-- tests/timings_io.py | 2 +- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/ants/core/ants_image_io.py b/ants/core/ants_image_io.py index d536c8cb..a7db90b3 100644 --- a/ants/core/ants_image_io.py +++ b/ants/core/ants_image_io.py @@ -18,6 +18,7 @@ import os import json import numpy as np +import warnings from . import ants_image as iio from .. import utils @@ -401,7 +402,8 @@ def image_read(filename, dimension=None, pixeltype='float'): ndim = dimension if ptype in _unsupported_ptypes: - raise ValueError('unsupported pixeltype %s' % ptype) + #warnings.warn('Casting image from unsupported type \'%s\' to closest supported type \'%s\'' % (ptype, _unsupported_ptype_map[ptype])) + ptype = _unsupported_ptype_map.get(ptype, 'unsupported') libfn = utils.get_lib_fn(_image_read_dict[pclass][ptype][ndim]) itk_pointer = libfn(filename) diff --git a/ants/viz/plot.py b/ants/viz/plot.py index 159a1135..5c145c78 100644 --- a/ants/viz/plot.py +++ b/ants/viz/plot.py @@ -195,7 +195,7 @@ def plot(image, overlay=None, cmap='Greys_r', overlay_cmap='jet', overlay_alpha= def plot_directory(directory, recursive=False, regex='*', - save_prefix='', save_suffix='', **kwargs): + save_prefix='', save_suffix='', axis=None, **kwargs): """ Create and save an ANTsPy plot for every image matching a given regular expression in a directory, optionally recursively. This is a good function @@ -254,8 +254,14 @@ def has_acceptable_suffix(fname): fname = '%s%s' % (save_prefix, fname) save_fname = os.path.join(root, fname) img = iio2.image_read(load_fname) + + if axis is None: + axis_range = [i for i in range(img.dimension)] + else: + axis_range = axis if isinstance(axis,(list,tuple)) else [axis] + if img.dimension > 2: - for axis_idx in range(img.dimension): + for axis_idx in axis_range: filename = save_fname.replace('.png', '_axis%i.png' % axis_idx) ncol = int(math.sqrt(img.shape[axis_idx])) plot(img, axis=axis_idx, nslices=img.shape[axis_idx], ncol=ncol, diff --git a/tests/timings_io.py b/tests/timings_io.py index d823e58e..f044eb77 100644 --- a/tests/timings_io.py +++ b/tests/timings_io.py @@ -41,7 +41,7 @@ def test_itk(): def test_ants(): for img_path in img_paths: - array = ants.image_read(img_path).numpy() + array = ants.image_read(img_path, pixeltype='float').numpy() nib_start = time.time() for i in range(N_TRIALS):