Skip to content

Commit

Permalink
fix static type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
dyollb committed Apr 2, 2024
1 parent b88c1d3 commit b305c55
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/segmantic/commands/monai_unet_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def predict(
None, "--tissue-list", "-t", help="label descriptors in iSEG format"
),
results_dir: Path = typer.Option(
None, "--results-dir", "-r", help="output directory"
..., "--results-dir", "-r", help="output directory"
),
spacing: list[float] = typer.Option(
[], "--spacing", help="if specified, the image is first resampled"
Expand Down
6 changes: 3 additions & 3 deletions src/segmantic/seg/monai_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def __init__(
self.best_val_dice = 0.0
self.best_val_epoch = 0.0
self.validation_step_outputs: list[dict] = []
self.training_step_outputs = []
self.training_step_outputs: list[torch.Tensor] = []

@property
def num_channels(self):
Expand Down Expand Up @@ -591,8 +591,8 @@ def train(
def predict(
model_file: Path,
test_images: list[Path],
output_dir: Path,
test_labels: Optional[list[Path]] = None,
output_dir: Path = None,
tissue_dict: dict[str, int] = None,
channels: tuple[int, ...] = (16, 32, 64, 128, 256),
strides: tuple[int, ...] = (2, 2, 2, 2),
Expand Down Expand Up @@ -750,7 +750,7 @@ def to_one_hot(x):
delimiter=",",
)

if test_labels:
if test_labels and confusion is not None:
plot_confusion_matrix(
confusion,
tissue_names,
Expand Down

0 comments on commit b305c55

Please sign in to comment.