Skip to content

Commit

Permalink
🐛 remove model parameters in predict (#65)
Browse files Browse the repository at this point in the history
* remove model parameters in predict

---------

Co-authored-by: Bryn Lloyd <12702862+dyollb@users.noreply.github.com>
  • Loading branch information
dyollb and dyollb authored Nov 7, 2024
1 parent 288663f commit a6fe28f
Showing 1 changed file with 2 additions and 10 deletions.
12 changes: 2 additions & 10 deletions src/segmantic/seg/monai_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,9 +554,6 @@ def predict(
test_labels: Optional[list[Path]] = None,
output_dir: Path = None,
tissue_dict: dict[str, int] = None,
channels: tuple[int, ...] = (16, 32, 64, 128, 256),
strides: tuple[int, ...] = (2, 2, 2, 2),
dropout: float = 0.0,
spacing: Sequence[float] = [],
gpu_ids: list[int] = [],
) -> None:
Expand All @@ -568,9 +565,7 @@ def predict(
settings = json.load(json_file)
net: Net = Net.load_from_checkpoint(f"{model_file}", **settings)
else:
net = Net.load_from_checkpoint(
f"{model_file}", channels=channels, strides=strides, dropout=dropout
)
net = Net.load_from_checkpoint(f"{model_file}")
num_classes = net.num_classes

net.freeze()
Expand Down Expand Up @@ -765,7 +760,7 @@ def cross_validate(

for config_file in Path(config_files_dir).iterdir():
assert config_file.suffix in [".json", ".yml"], f"suffix: {config_file}"
is_json = config_file and config_file.suffix.lower() == ".json"
is_json = config_file.suffix.lower() == ".json"
dumps = partial(config.dumps, is_json=is_json)
loads = partial(config.loads, is_json=is_json)

Expand Down Expand Up @@ -823,9 +818,6 @@ def cross_validate(
test_images=test_images,
test_labels=test_labels,
tissue_dict=tissue_dict,
# channels=current_layers,
# strides=current_strides,
dropout=0.0,
spacing=[1, 1, 1],
gpu_ids=gpu_ids,
)
Expand Down

0 comments on commit a6fe28f

Please sign in to comment.