Skip to content

Commit

Permalink
Replace Click with Typer for CLI tools (#978)
Browse files Browse the repository at this point in the history
* Replace Click with Typer

* Fix entry points
  • Loading branch information
fepegar authored Oct 9, 2022
1 parent 5983f83 commit 3bef1e8
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 81 deletions.
8 changes: 4 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ project_urls =
install_requires =
Deprecated
SimpleITK!=2.0.*,!=2.1.1.1
click
humanize
nibabel
numpy>=1.15
scipy
torch>=1.1
tqdm
typer[all]
python_requires = >=3.7
include_package_data = True
zip_safe = False
Expand All @@ -54,9 +54,9 @@ where = src

[options.entry_points]
console_scripts =
tiohd=torchio.cli.print_info:main
tiotr=torchio.cli.apply_transform:main
torchio-transform=torchio.cli.apply_transform:main
tiohd=torchio.cli.print_info:app
tiotr=torchio.cli.apply_transform:app
torchio-transform=torchio.cli.apply_transform:app

[options.extras_require]
all =
Expand Down
114 changes: 66 additions & 48 deletions src/torchio/cli/apply_transform.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,63 @@
# pylint: disable=import-outside-toplevel
"""Console script for torchio."""
import sys

import click
from pathlib import Path

import typer
from rich.progress import Progress, SpinnerColumn, TextColumn

@click.command()
@click.argument('input-path', type=click.Path(exists=True))
@click.argument('transform-name', type=str)
@click.argument('output-path', type=click.Path())
@click.option(
'--kwargs', '-k',
type=str,
help='String of kwargs, e.g. "degrees=(-5,15) num_transforms=3".',
)
@click.option(
'--imclass', '-c',
type=str,
default='ScalarImage',
help='Subclass of torchio.Image used to instantiate the image.',
)
@click.option(
'--seed', '-s',
type=int,
help='Seed for PyTorch random number generator.',
)
@click.option(
'--verbose/--no-verbose', '-v',
type=bool,
default=False,
help='Print random transform parameters.',
)

app = typer.Typer()


@app.command()
def main(
input_path,
transform_name,
output_path,
kwargs,
imclass,
seed,
verbose,
input_path: Path = typer.Argument( # noqa: B008
...,
exists=True,
file_okay=True,
dir_okay=True,
readable=True,
),
transform_name: str = typer.Argument(...), # noqa: B008
output_path: Path = typer.Argument( # noqa: B008
...,
file_okay=True,
dir_okay=False,
writable=True,
),
kwargs: str = typer.Option( # noqa: B008
None,
'--kwargs', '-k',
help='String of kwargs, e.g. "degrees=(-5,15) num_transforms=3".',
),
imclass: str = typer.Option( # noqa: B008
'ScalarImage',
'--imclass', '-c',
help=(
'Name of the subclass of torchio.Image'
' that will be used to instantiate the image.'
),
),
seed: int = typer.Option( # noqa: B008
None,
'--seed', '-s',
help='Seed for PyTorch random number generator.',
),
verbose: bool = typer.Option( # noqa: B008
False,
help='Print random transform parameters.',
),
show_progress: bool = typer.Option( # noqa: B008
True,
'--show-progress/--hide-progress',
'-p/-P',
help='Show animations indicating progress.',
),
):
"""Apply transform to an image.
\b
Example:
$ torchio-transform -k "degrees=(-5,15) num_transforms=3" input.nrrd RandomMotion output.nii
$ tiotr input.nrrd RandomMotion output.nii "degrees=(-5,15) num_transforms=3" -v
""" # noqa: E501
# Imports are placed here so that the tool loads faster if not being run
import torch
Expand All @@ -61,14 +74,20 @@ def main(
transform = transform_class(**params_dict)
if seed is not None:
torch.manual_seed(seed)
apply_transform_to_file(
input_path,
transform,
output_path,
verbose=verbose,
class_=imclass,
)
return 0
with Progress(
SpinnerColumn(),
TextColumn('[progress.description]{task.description}'), # noqa: FS003
transient=True,
disable=not show_progress,
) as progress:
progress.add_task('Applying transform', total=1)
apply_transform_to_file(
input_path,
transform,
output_path,
verbose=verbose,
class_=imclass,
)


def get_params_dict_from_kwargs(kwargs):
Expand All @@ -88,5 +107,4 @@ def get_params_dict_from_kwargs(kwargs):


if __name__ == '__main__':
# pylint: disable=no-value-for-parameter
sys.exit(main()) # pragma: no cover
app()
44 changes: 31 additions & 13 deletions src/torchio/cli/print_info.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,39 @@
# pylint: disable=import-outside-toplevel
"""Console script for torchio."""
import sys
from pathlib import Path

import click
import typer


@click.command()
@click.argument('input-path', type=click.Path(exists=True))
@click.option('--plot/--no-plot', '-p', default=False)
@click.option('--show/--no-show', '-s', default=False)
@click.option('--label/--scalar', '-l', default=False)
def main(input_path, plot, show, label):
app = typer.Typer()


@app.command()
def main(
input_path: Path = typer.Argument( # noqa: B008
...,
exists=True,
file_okay=True,
dir_okay=True,
readable=True,
),
plot: bool = typer.Option( # noqa: B008
False,
'--plot/--no-plot', '-p/-P',
help='Plot the image using Matplotlib or Pillow.',
),
show: bool = typer.Option( # noqa: B008
False,
'--show/--no-show', '-s/-S',
help='Show the image using specialized visualisation software.',
),
label: bool = typer.Option( # noqa: B008
False,
'--label/--scalar', '-l/-s',
help='Use torchio.LabelMap to instantiate the image.',
),
):
"""Print information about an image and, optionally, show it.
\b
Example:
$ tiohd input.nii.gz
"""
Expand All @@ -27,9 +47,7 @@ def main(input_path, plot, show, label):
image.plot()
if show:
image.show()
return 0


if __name__ == '__main__':
# pylint: disable=no-value-for-parameter
sys.exit(main()) # pragma: no cover
app()
25 changes: 9 additions & 16 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,40 @@
#!/usr/bin/env python
"""Tests for CLI tool package."""
from click.testing import CliRunner
from typer.testing import CliRunner
from torchio.cli import apply_transform
from torchio.cli import print_info

from .utils import TorchioTestCase


runner = CliRunner()


class TestCLI(TorchioTestCase):
"""Tests for CLI tool."""
def test_help(self):
"""Test the CLI."""
runner = CliRunner()
help_result = runner.invoke(apply_transform.main, ['--help'])
assert help_result.exit_code == 0
assert 'Show this message and exit.' in help_result.output

def test_cli_transform(self):
image = str(self.get_image_path('cli'))
runner = CliRunner()
args = [
image,
'RandomFlip',
'--seed', '0',
'--kwargs', 'axes=(0,1,2)',
'--hide-progress',
image,
]
result = runner.invoke(apply_transform.main, args)
result = runner.invoke(apply_transform.app, args)
assert result.exit_code == 0
assert result.output == ''
assert result.output.strip() == ''

def test_bad_transform(self):
ValueError
image = str(self.get_image_path('cli'))
runner = CliRunner()
args = [image, 'RandomRandom', image]
result = runner.invoke(apply_transform.main, args)
result = runner.invoke(apply_transform.app, args)
assert result.exit_code == 1

def test_cli_hd(self):
image = str(self.get_image_path('cli'))
runner = CliRunner()
args = [image]
result = runner.invoke(print_info.main, args)
result = runner.invoke(print_info.app, args)
assert result.exit_code == 0
assert result.output == 'ScalarImage(shape: (1, 10, 20, 30); spacing: (1.00, 1.00, 1.00); orientation: RAS+; dtype: torch.DoubleTensor; memory: 46.9 KiB)\n' # noqa: E501

0 comments on commit 3bef1e8

Please sign in to comment.