From ac4824c6416d3fbd1e4b38f0730b614f1a1cc5b8 Mon Sep 17 00:00:00 2001
From: Rudolph Pienaar
Date: Thu, 25 Apr 2024 10:14:09 -0400
Subject: [PATCH] Start improving telemetry tracking

---
 spleenseg/core/neuralnet.py | 134 ++++++++++++++++++++++++++++--------
 1 file changed, 106 insertions(+), 28 deletions(-)

diff --git a/spleenseg/core/neuralnet.py b/spleenseg/core/neuralnet.py
index 00697b4..efba6fa 100644
--- a/spleenseg/core/neuralnet.py
+++ b/spleenseg/core/neuralnet.py
@@ -18,10 +18,23 @@ from typing import Any, Sequence
 import numpy as np
+import shutil
+import nibabel as nib
+from nibabel.nifti1 import Nifti1Image
+from nibabel.nifti2 import Nifti2Image
 
 from spleenseg.transforms import transforms
 from spleenseg.models import data
 from spleenseg.plotting import plotting
+
+
+def dictFirstKeyValue_getFromList(lst: list[dict[str, Any]], i: int = 0) -> Any:
+    """Return the value of the first key of dict i in lst, or None."""
+    val: Any = None
+    if 0 <= i < len(lst) and lst[i]:
+        key: str = list(lst[i].keys())[0]
+        val = lst[i][key]
+    return val
 
 
 def tensor_desc(
@@ -64,6 +77,8 @@ def __init__(self, options: Namespace):
 
         self.trainingLog: data.TrainingLog = data.TrainingLog()
         self.trainingEpoch: int = 0
+        self.whileTrainingNIfTIsaved: bool = False
+        self.whileTrainingValidationNIfTIsaved: bool = False
 
         self.f_outputPost: Compose
         self.f_labelPost: Compose
@@ -77,7 +92,12 @@ def __init__(self, options: Namespace):
         self.testingSpace: data.LoaderCache
 
     def loaderCache_create(
-        self, fileList: list[dict[str, str]], transforms: Compose, batch_size: int = 2
+        self,
+        fileList: list[dict[str, str]],
+        transforms: Compose,
+        batch_size: int = 2,
+        title: str = "",
+        saveExemplar: Path = Path(""),
     ) -> data.LoaderCache:
         """
         ## Define CacheDataset and DataLoader for training and validation
@@ -95,6 +115,15 @@
 
         NB: Parameterize all params!!
        """
+
+        NIfTIfile: str | None = dictFirstKeyValue_getFromList(fileList)
+        if NIfTIfile is not None:
+            if saveExemplar.parts:
+                shutil.copy(str(NIfTIfile), str(saveExemplar))
+            nifti: Nifti1Image | Nifti2Image = nib.load(NIfTIfile)
+            if len(title):
+                print(f"{title}", end="")
+            print(f"shape: '{NIfTIfile}': {nifti.shape}")
         ds: CacheDataset = CacheDataset(
             data=fileList, transform=transforms, cache_rate=1.0, num_workers=4
         )
@@ -125,10 +154,18 @@ def trainingTransformsAndSpace_setup(self) -> bool:
             return False
 
         self.trainingSpace = self.loaderCache_create(
-            self.trainingFileSet, trainingTransforms
+            self.trainingFileSet,
+            trainingTransforms,
+            2,
+            "training set exemplar: ",
+            self.trainingParams.preTrainingIO / "input.nii.gz",
         )
         self.validationSpace = self.loaderCache_create(
-            self.validationFileSet, validationTransforms, 1
+            self.validationFileSet,
+            validationTransforms,
+            1,
+            "validation set exemplar: ",
+            self.trainingParams.preTrainingIO / "validation.nii.gz",
         )
 
         return setupOK
@@ -180,6 +217,43 @@ def evalAndCorrect(self) -> float:
         self.network.optimizer.step()
         return f_loss.item()
 
+    def metaTensor_toNIfTI(self, metaTensor: MetaTensor, savefile: Path) -> None:
+        # Save volume [0, 0] (first batch item, first channel) with an identity affine.
+        singleVolume: np.ndarray = metaTensor[0, 0].cpu().numpy()
+        affine: np.ndarray = np.eye(4)
+        niftiVolume: Nifti1Image = Nifti1Image(singleVolume, affine)
+        nib.save(niftiVolume, savefile)
+
+    def sample_showInfo(
+        self,
+        sample: int,
+        tensor: list[MetaTensor | torch.Tensor],
+        prefix: list[str],
+        saveto: list[Path],
+        saveVolumes: bool = True,
+    ) -> None:
+        if sample != 1:
+            return
+        # Log each tensor's shape; optionally persist it as NIfTI for inspection.
+        for T, txt, savefile in zip(tensor, prefix, saveto):
+            print(f"{txt} shape: {T.shape}")
+            if self.trainingParams.options.logTrainingTransformVols and saveVolumes:
+                print(f"{txt} save: {savefile}")
+                self.metaTensor_toNIfTI(T, savefile)
+
+    def sample_showSummary(
+        self, sample: int, sample_loss: float, trainingSpace: data.LoaderCache
+    ) -> None:
+        if (
+            trainingSpace.cache is not None
+            and trainingSpace.loader.batch_size is not None
+        ):
+            print(
+                f" training run {sample:02}/"
+                f"{len(trainingSpace.cache) // trainingSpace.loader.batch_size}, "
+                f"sample loss: {sample_loss:.4f}"
+            )
+
     def train_overSampleSpace_retLoss(self, trainingSpace: data.LoaderCache) -> float:
         sample: int = 0
         sample_loss: float = 0.0
@@ -190,38 +264,32 @@ def train_overSampleSpace_retLoss(self, trainingSpace: data.LoaderCache) -> floa
                 trainingInstance["image"].to(self.network.device),
                 trainingInstance["label"].to(self.network.device),
             )
-            if sample == 1:
-                print(f"training image shape: {self.input.shape}")
-                print(f"training label shape: {self.target.shape}")
+            self.sample_showInfo(
+                sample,
+                [self.input, self.target],
+                ["training image", "training label"],
+                [
+                    self.trainingParams.whileTrainingIO / "input.nii.gz",
+                    self.trainingParams.whileTrainingIO / "label.nii.gz",
+                ],
+                not self.whileTrainingNIfTIsaved,
+            )
+            self.whileTrainingNIfTIsaved = True
             sample_loss = self.evalAndCorrect()
             total_loss += sample_loss
-            if (
-                trainingSpace.cache is not None
-                and trainingSpace.loader.batch_size is not None
-            ):
-                print(
-                    f" training run {sample:02}/"
-                    f"{len(trainingSpace.cache) // trainingSpace.loader.batch_size}, "
-                    f"sample loss: {sample_loss:.4f}"
-                )
+            self.sample_showSummary(sample, sample_loss, trainingSpace)
         total_loss /= sample
         return total_loss
 
-    def train(
-        self,
-        trainingSpace: data.LoaderCache | None = None,
-        validationSpace: data.LoaderCache | None = None,
-    ):
+    def train(self, useModelFile: Path | None = None):
         self.f_outputPost = transforms.transforms_build(
             [transforms.f_AsDiscreteArgMax()]
         )
         self.f_labelPost = transforms.transforms_build([transforms.f_AsDiscrete()])
         self.trainingEpoch = 0
         epoch_loss: float = 0.0
-        if trainingSpace:
-            self.trainingSpace = trainingSpace
-        if validationSpace:
-            self.validationSpace = validationSpace
+        if useModelFile is not None:
+            self.network.model.load_state_dict(torch.load(str(useModelFile)))
         for self.trainingEpoch in range(self.trainingParams.max_epochs):
             print("-" * 10)
             print(
@@ -334,9 +402,19 @@ def slidingWindowInference_do(
                 input, roi_size, sw_batch_size, self.network.model
             )
         )
-        if index == 1:
-            print(f"inference input shape: {input.shape}")
-            print(f"inference output shape: {outputRaw.shape}")
+        sample["input"] = input.cpu()
+        sample["output"] = outputRaw.cpu()
+        self.sample_showInfo(
+            index,
+            [input, outputRaw],
+            ["validation inference input", "validation inference output"],
+            [
+                self.trainingParams.whileTrainingValidation / "input.nii.gz",
+                self.trainingParams.whileTrainingValidation / "output.nii.gz",
+            ],
+            not self.whileTrainingValidationNIfTIsaved,
+        )
+        self.whileTrainingValidationNIfTIsaved = True
         if f_callback is not None:
             metric = f_callback(sample, inferSpace, index, outputRaw)
         return metric
@@ -353,7 +431,7 @@ def plot_bestModel(
             sample,
             result,
             str(index),
-            self.trainingParams.outputDir / "validation" / f"bestModel-val-{index}.png",
+            self.trainingParams.postTrainingValidation / f"bestModel-val-{index}.png",
         )
         return 0.0
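
Note (not part of the patch): a minimal smoke-test sketch of the two helpers
the patch introduces. It assumes spleenseg is importable; the spleen_*.nii.gz
file names and the exemplar.nii.gz output path are illustrative placeholders,
not files shipped with the package:

    import nibabel as nib
    import numpy as np
    import torch
    from nibabel.nifti1 import Nifti1Image

    from spleenseg.core.neuralnet import dictFirstKeyValue_getFromList

    # A MONAI-style file list: each entry maps keys to NIfTI paths.
    fileList: list[dict[str, str]] = [
        {"image": "spleen_01.nii.gz", "label": "spleen_01_seg.nii.gz"},
        {"image": "spleen_02.nii.gz", "label": "spleen_02_seg.nii.gz"},
    ]

    # Value of the first key of the first dict: "spleen_01.nii.gz"
    print(dictFirstKeyValue_getFromList(fileList))

    # The same identity-affine save that metaTensor_toNIfTI() performs,
    # here on a synthetic [batch, channel, x, y, z] tensor:
    T: torch.Tensor = torch.rand(1, 1, 96, 96, 96)
    volume: np.ndarray = T[0, 0].cpu().numpy()
    nib.save(Nifti1Image(volume, np.eye(4)), "exemplar.nii.gz")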