diff --git a/setup.cfg b/setup.cfg index 79cacdd08..d7417fb44 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,13 +38,13 @@ project_urls = install_requires = Deprecated SimpleITK!=2.0.*,!=2.1.1.1 - click humanize nibabel numpy>=1.15 scipy torch>=1.1 tqdm + typer[all] python_requires = >=3.7 include_package_data = True zip_safe = False @@ -54,9 +54,9 @@ where = src [options.entry_points] console_scripts = - tiohd=torchio.cli.print_info:main - tiotr=torchio.cli.apply_transform:main - torchio-transform=torchio.cli.apply_transform:main + tiohd=torchio.cli.print_info:app + tiotr=torchio.cli.apply_transform:app + torchio-transform=torchio.cli.apply_transform:app [options.extras_require] all = diff --git a/src/torchio/cli/apply_transform.py b/src/torchio/cli/apply_transform.py index 631a5e374..8b012b109 100644 --- a/src/torchio/cli/apply_transform.py +++ b/src/torchio/cli/apply_transform.py @@ -1,50 +1,63 @@ # pylint: disable=import-outside-toplevel -"""Console script for torchio.""" -import sys -import click +from pathlib import Path +import typer +from rich.progress import Progress, SpinnerColumn, TextColumn -@click.command() -@click.argument('input-path', type=click.Path(exists=True)) -@click.argument('transform-name', type=str) -@click.argument('output-path', type=click.Path()) -@click.option( - '--kwargs', '-k', - type=str, - help='String of kwargs, e.g. "degrees=(-5,15) num_transforms=3".', -) -@click.option( - '--imclass', '-c', - type=str, - default='ScalarImage', - help='Subclass of torchio.Image used to instantiate the image.', -) -@click.option( - '--seed', '-s', - type=int, - help='Seed for PyTorch random number generator.', -) -@click.option( - '--verbose/--no-verbose', '-v', - type=bool, - default=False, - help='Print random transform parameters.', -) + +app = typer.Typer() + + +@app.command() def main( - input_path, - transform_name, - output_path, - kwargs, - imclass, - seed, - verbose, + input_path: Path = typer.Argument( # noqa: B008 + ..., + exists=True, + file_okay=True, + dir_okay=True, + readable=True, + ), + transform_name: str = typer.Argument(...), # noqa: B008 + output_path: Path = typer.Argument( # noqa: B008 + ..., + file_okay=True, + dir_okay=False, + writable=True, + ), + kwargs: str = typer.Option( # noqa: B008 + None, + '--kwargs', '-k', + help='String of kwargs, e.g. "degrees=(-5,15) num_transforms=3".', + ), + imclass: str = typer.Option( # noqa: B008 + 'ScalarImage', + '--imclass', '-c', + help=( + 'Name of the subclass of torchio.Image' + ' that will be used to instantiate the image.' + ), + ), + seed: int = typer.Option( # noqa: B008 + None, + '--seed', '-s', + help='Seed for PyTorch random number generator.', + ), + verbose: bool = typer.Option( # noqa: B008 + False, + help='Print random transform parameters.', + ), + show_progress: bool = typer.Option( # noqa: B008 + True, + '--show-progress/--hide-progress', + '-p/-P', + help='Show animations indicating progress.', + ), ): """Apply transform to an image. - \b Example: - $ torchio-transform -k "degrees=(-5,15) num_transforms=3" input.nrrd RandomMotion output.nii + $ tiotr input.nrrd RandomMotion output.nii "degrees=(-5,15) num_transforms=3" -v """ # noqa: E501 # Imports are placed here so that the tool loads faster if not being run import torch @@ -61,14 +74,20 @@ def main( transform = transform_class(**params_dict) if seed is not None: torch.manual_seed(seed) - apply_transform_to_file( - input_path, - transform, - output_path, - verbose=verbose, - class_=imclass, - ) - return 0 + with Progress( + SpinnerColumn(), + TextColumn('[progress.description]{task.description}'), # noqa: FS003 + transient=True, + disable=not show_progress, + ) as progress: + progress.add_task('Applying transform', total=1) + apply_transform_to_file( + input_path, + transform, + output_path, + verbose=verbose, + class_=imclass, + ) def get_params_dict_from_kwargs(kwargs): @@ -88,5 +107,4 @@ def get_params_dict_from_kwargs(kwargs): if __name__ == '__main__': - # pylint: disable=no-value-for-parameter - sys.exit(main()) # pragma: no cover + app() diff --git a/src/torchio/cli/print_info.py b/src/torchio/cli/print_info.py index 1cb50a203..2ba658ba2 100644 --- a/src/torchio/cli/print_info.py +++ b/src/torchio/cli/print_info.py @@ -1,19 +1,39 @@ # pylint: disable=import-outside-toplevel -"""Console script for torchio.""" -import sys +from pathlib import Path -import click +import typer -@click.command() -@click.argument('input-path', type=click.Path(exists=True)) -@click.option('--plot/--no-plot', '-p', default=False) -@click.option('--show/--no-show', '-s', default=False) -@click.option('--label/--scalar', '-l', default=False) -def main(input_path, plot, show, label): +app = typer.Typer() + + +@app.command() +def main( + input_path: Path = typer.Argument( # noqa: B008 + ..., + exists=True, + file_okay=True, + dir_okay=True, + readable=True, + ), + plot: bool = typer.Option( # noqa: B008 + False, + '--plot/--no-plot', '-p/-P', + help='Plot the image using Matplotlib or Pillow.', + ), + show: bool = typer.Option( # noqa: B008 + False, + '--show/--no-show', '-s/-S', + help='Show the image using specialized visualisation software.', + ), + label: bool = typer.Option( # noqa: B008 + False, + '--label/--scalar', '-l/-s', + help='Use torchio.LabelMap to instantiate the image.', + ), +): """Print information about an image and, optionally, show it. - \b Example: $ tiohd input.nii.gz """ @@ -27,9 +47,7 @@ def main(input_path, plot, show, label): image.plot() if show: image.show() - return 0 if __name__ == '__main__': - # pylint: disable=no-value-for-parameter - sys.exit(main()) # pragma: no cover + app() diff --git a/tests/test_cli.py b/tests/test_cli.py index 2d24663fb..2b3fceeec 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,47 +1,40 @@ #!/usr/bin/env python """Tests for CLI tool package.""" -from click.testing import CliRunner +from typer.testing import CliRunner from torchio.cli import apply_transform from torchio.cli import print_info from .utils import TorchioTestCase +runner = CliRunner() + + class TestCLI(TorchioTestCase): - """Tests for CLI tool.""" - def test_help(self): - """Test the CLI.""" - runner = CliRunner() - help_result = runner.invoke(apply_transform.main, ['--help']) - assert help_result.exit_code == 0 - assert 'Show this message and exit.' in help_result.output def test_cli_transform(self): image = str(self.get_image_path('cli')) - runner = CliRunner() args = [ image, 'RandomFlip', '--seed', '0', '--kwargs', 'axes=(0,1,2)', + '--hide-progress', image, ] - result = runner.invoke(apply_transform.main, args) + result = runner.invoke(apply_transform.app, args) assert result.exit_code == 0 - assert result.output == '' + assert result.output.strip() == '' def test_bad_transform(self): - ValueError image = str(self.get_image_path('cli')) - runner = CliRunner() args = [image, 'RandomRandom', image] - result = runner.invoke(apply_transform.main, args) + result = runner.invoke(apply_transform.app, args) assert result.exit_code == 1 def test_cli_hd(self): image = str(self.get_image_path('cli')) - runner = CliRunner() args = [image] - result = runner.invoke(print_info.main, args) + result = runner.invoke(print_info.app, args) assert result.exit_code == 0 assert result.output == 'ScalarImage(shape: (1, 10, 20, 30); spacing: (1.00, 1.00, 1.00); orientation: RAS+; dtype: torch.DoubleTensor; memory: 46.9 KiB)\n' # noqa: E501