Skip to content

Commit

Permalink
- rm output mode parameter (will always return numpy data and save file…
Browse files Browse the repository at this point in the history
…s if paths provided)
  • Loading branch information
MarcelRosier committed Jan 29, 2024
1 parent 5a63c34 commit 7eed04b
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 27 deletions.
3 changes: 0 additions & 3 deletions brainles_aurora/inferer/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@ class BaseConfig:
"""Base configuration for the Aurora model inferer.
Attributes:
output_mode (DataMode, optional): Output mode for the inference results. Defaults to DataMode.NIFTI_FILE.
log_level (int | str, optional): Logging level. Defaults to logging.INFO.
"""

output_mode: DataMode = DataMode.NIFTI_FILE
log_level: int | str = logging.INFO


Expand All @@ -24,7 +22,6 @@ class AuroraInfererConfig(BaseConfig):
"""Configuration for the Aurora model inferer.
Attributes:
output_mode (DataMode, optional): Output mode for the inference results. Defaults to DataMode.NIFTI_FILE.
log_level (int | str, optional): Logging level. Defaults to logging.INFO.
tta (bool, optional): Whether to apply test-time augmentations. Defaults to True.
sliding_window_batch_size (int, optional): Batch size for sliding window inference. Defaults to 1.
Expand Down
36 changes: 14 additions & 22 deletions brainles_aurora/inferer/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,11 +427,11 @@ def _post_process(
Output.METASTASIS_NETWORK: enhancing_out,
}

def _sliding_window_inference(self) -> None | Dict[str, np.ndarray]:
def _sliding_window_inference(self) -> Dict[str, np.ndarray]:
"""Perform sliding window inference using monai.inferers.SlidingWindowInferer.
Returns:
None | Dict[str, np.ndarray]: Post-processed data if output_mode is NUMPY, otherwise the data is saved as a niftis and None is returned.
Dict[str, np.ndarray]: Post-processed data
"""
inferer = SlidingWindowInferer(
roi_size=self.config.crop_size, # = patch_size
Expand Down Expand Up @@ -461,15 +461,14 @@ def _sliding_window_inference(self) -> None | Dict[str, np.ndarray]:
postprocessed_data = self._post_process(
onehot_model_outputs_CHWD=outputs,
)
if self.config.output_mode == DataMode.NUMPY:
self.log.info(
"Returning post-processed data as Dict of Numpy arrays"
)
return postprocessed_data
else:

# save data to file if paths are provided
if any(self.output_file_mapping.values()):
self.log.info("Saving post-processed data as NIFTI files")
self._save_as_nifti(postproc_data=postprocessed_data)
return None

self.log.info("Returning post-processed data as Dict of Numpy arrays")
return postprocessed_data

def _configure_device(self) -> torch.device:
"""Configure the device for inference.
Expand Down Expand Up @@ -511,7 +510,7 @@ def infer(
log_file (str | Path | None, optional): _description_. Defaults to None.
Returns:
Dict[str, np.ndarray] | None: Post-processed data if output_mode is NUMPY, otherwise the data is saved as a niftis and None is returned.
Dict[str, np.ndarray]: Post-processed data.
"""
# setup logger for inference run
if log_file:
Expand Down Expand Up @@ -545,18 +544,11 @@ def infer(
self.data_loader = self._get_data_loader()

# setup output file paths
if self.config.output_mode == DataMode.NIFTI_FILE:
# TODO add error handling to ensure file extensions present
if not segmentation_file:
default_segmentation_path = os.path.abspath("./segmentation.nii.gz")
self.log.warning(
f"No segmentation file name provided, using default path: {default_segmentation_path}"
)
self.output_file_mapping = {
Output.SEGMENTATION: segmentation_file or default_segmentation_path,
Output.WHOLE_NETWORK: whole_tumor_unbinarized_floats_file,
Output.METASTASIS_NETWORK: metastasis_unbinarized_floats_file,
}
self.output_file_mapping = {
Output.SEGMENTATION: segmentation_file,
Output.WHOLE_NETWORK: whole_tumor_unbinarized_floats_file,
Output.METASTASIS_NETWORK: metastasis_unbinarized_floats_file,
}

########
self.log.info(f"Running inference on device := {self.device}")
Expand Down
2 changes: 0 additions & 2 deletions segmentation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def cpu_nifti():
def gpu_np():
config = AuroraInfererConfig(
tta=False,
output_mode=DataMode.NUMPY,
) # disable tta for faster inference in this showcase

# If you don't have a GPU that supports CUDA, use the CPU version: AuroraInferer(config=config)
Expand All @@ -95,7 +94,6 @@ def gpu_output_np():
t1c=load_np_from_nifti(t1c),
t2=load_np_from_nifti(t2),
fla=load_np_from_nifti(fla),
output_mode=DataMode.NUMPY,
)
inferer = AuroraGPUInferer(
config=config,
Expand Down

0 comments on commit 7eed04b

Please sign in to comment.