Start improving telemetry tracking
rudolphpienaar committed Apr 25, 2024
1 parent 2066bf1 commit ac4824c
Showing 1 changed file with 106 additions and 28 deletions.
134 changes: 106 additions & 28 deletions spleenseg/core/neuralnet.py
@@ -18,10 +18,23 @@

from typing import Any, Sequence
import numpy as np
import nibabel as nib
from nibabel.nifti1 import Nifti1Image

from spleenseg.transforms import transforms
from spleenseg.models import data
from spleenseg.plotting import plotting
import shutil


def dictFirstKeyValue_getFromList(l: list[dict[str, Any]], i: int = 0) -> Any:
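    """Return the value of the first key of the i-th dict in a list of dicts.

    Example: dictFirstKeyValue_getFromList([{"image": "a.nii.gz"}]) -> "a.nii.gz".
    Returns None for an empty list, an out-of-range index, or an empty dict.
    """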
val: Any = None
    if len(l) > i:
        if l[i]:
key: str = list(l[i].keys())[0]
val = l[i][key]
return val


def tensor_desc(
@@ -64,6 +77,8 @@ def __init__(self, options: Namespace):

self.trainingLog: data.TrainingLog = data.TrainingLog()
self.trainingEpoch: int = 0
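        # Telemetry flags: exemplar NIfTI volumes are saved only once per run;
        # these guards prevent re-saving on every sample.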
self.whileTrainingNIfTIsaved: bool = False
self.whileTrainingValidationNIfTIsaved: bool = False

self.f_outputPost: Compose
self.f_labelPost: Compose
@@ -77,7 +92,12 @@ def __init__(self, options: Namespace):
self.testingSpace: data.LoaderCache

def loaderCache_create(
-        self, fileList: list[dict[str, str]], transforms: Compose, batch_size: int = 2
+        self,
+        fileList: list[dict[str, str]],
+        transforms: Compose,
+        batch_size: int = 2,
+        title: str = "",
+        saveExemplar: Path = Path(""),
) -> data.LoaderCache:
"""
## Define CacheDataset and DataLoader for training and validation
@@ -95,6 +115,15 @@ def loaderCache_create(
NB: Parameterize all params!!
"""

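        # Telemetry: report the shape of an exemplar volume from this file set,
        # and copy it aside when a save destination is supplied.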
        NIfTIfile: str | None = dictFirstKeyValue_getFromList(fileList)
        if NIfTIfile is not None:
            if saveExemplar.parts:
                shutil.copy(str(NIfTIfile), str(saveExemplar))
            nifti: Nifti1Image = nib.load(NIfTIfile)
            if len(title):
                print(f"{title}", end="")
            print(f"shape: '{NIfTIfile}': {nifti.shape}")
ds: CacheDataset = CacheDataset(
data=fileList, transform=transforms, cache_rate=1.0, num_workers=4
)
@@ -125,10 +154,18 @@ def trainingTransformsAndSpace_setup(self) -> bool:
return False

self.trainingSpace = self.loaderCache_create(
-            self.trainingFileSet, trainingTransforms
+            self.trainingFileSet,
+            trainingTransforms,
+            2,
+            "training set exemplar: ",
+            self.trainingParams.preTrainingIO / "input.nii.gz",
)
self.validationSpace = self.loaderCache_create(
-            self.validationFileSet, validationTransforms, 1
+            self.validationFileSet,
+            validationTransforms,
+            1,
+            "validation set exemplar: ",
+            self.trainingParams.preTrainingIO / "validation.nii.gz",
)
return setupOK

@@ -180,6 +217,43 @@ def evalAndCorrect(self) -> float:
self.network.optimizer.step()
return f_loss.item()

def metaTensor_toNIfTI(self, metaTensor: MetaTensor, savefile: Path):
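        """
        Save the first volume of a (batch, channel, ...) MetaTensor to disk
        as a NIfTI file with an identity affine.
        """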
singleVolume: np.ndarray = metaTensor[0, 0].cpu().numpy()
affine: np.ndarray = np.eye(4)
niftiVolume: Nifti1Image = Nifti1Image(singleVolume, affine)
nib.save(niftiVolume, savefile)

def sample_showInfo(
self,
sample: int,
tensor: list[MetaTensor | torch.Tensor],
prefix: list[str],
saveto: list[Path],
saveVolumes: bool = True,
):
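        """
        For the first sample only: print each tensor's shape and, when
        transform-volume logging is enabled, save each volume as NIfTI.
        """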
if sample != 1:
return
for T, txt, savefile in zip(tensor, prefix, saveto):
print(f"{txt} shape: {T.shape}")
if self.trainingParams.options.logTrainingTransformVols and saveVolumes:
print(f"{txt} save: {savefile}")
self.metaTensor_toNIfTI(T, savefile)

def sample_showSummary(
self, sample: int, sample_loss: float, trainingSpace: data.LoaderCache
):
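        """Print a progress line for this sample: run index, total runs, and loss."""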
if (
trainingSpace.cache is not None
and trainingSpace.loader.batch_size is not None
):
print(
f" training run {sample:02}/"
f"{len(trainingSpace.cache) // trainingSpace.loader.batch_size}, "
f"sample loss: {sample_loss:.4f}"
)

def train_overSampleSpace_retLoss(self, trainingSpace: data.LoaderCache) -> float:
sample: int = 0
sample_loss: float = 0.0
@@ -190,38 +264,32 @@ def train_overSampleSpace_retLoss(self, trainingSpace: data.LoaderCache) -> float:
trainingInstance["image"].to(self.network.device),
trainingInstance["label"].to(self.network.device),
)
-            if sample == 1:
-                print(f"training image shape: {self.input.shape}")
-                print(f"training label shape: {self.target.shape}")
+            self.sample_showInfo(
+                sample,
+                [self.input, self.target],
+                ["training image", "training label"],
+                [
+                    self.trainingParams.whileTrainingIO / "input.nii.gz",
+                    self.trainingParams.whileTrainingIO / "label.nii.gz",
+                ],
+                not self.whileTrainingNIfTIsaved,
+            )
+            self.whileTrainingNIfTIsaved = True
sample_loss = self.evalAndCorrect()
total_loss += sample_loss
-            if (
-                trainingSpace.cache is not None
-                and trainingSpace.loader.batch_size is not None
-            ):
-                print(
-                    f" training run {sample:02}/"
-                    f"{len(trainingSpace.cache) // trainingSpace.loader.batch_size}, "
-                    f"sample loss: {sample_loss:.4f}"
-                )
+            self.sample_showSummary(sample, sample_loss, trainingSpace)
total_loss /= sample
return total_loss

-    def train(
-        self,
-        trainingSpace: data.LoaderCache | None = None,
-        validationSpace: data.LoaderCache | None = None,
-    ):
+    def train(self, useModelFile: Path | None = None):
self.f_outputPost = transforms.transforms_build(
[transforms.f_AsDiscreteArgMax()]
)
self.f_labelPost = transforms.transforms_build([transforms.f_AsDiscrete()])
self.trainingEpoch = 0
epoch_loss: float = 0.0
-        if trainingSpace:
-            self.trainingSpace = trainingSpace
-        if validationSpace:
-            self.validationSpace = validationSpace
+        if useModelFile is not None:
+            self.network.model.load_state_dict(torch.load(str(useModelFile)))
for self.trainingEpoch in range(self.trainingParams.max_epochs):
print("-" * 10)
print(
Expand Down Expand Up @@ -334,9 +402,19 @@ def slidingWindowInference_do(
input, roi_size, sw_batch_size, self.network.model
)
)
-        if index == 1:
-            print(f"inference input shape: {input.shape}")
-            print(f"inference output shape: {outputRaw.shape}")
        sample["input"] = input.cpu()
        sample["output"] = outputRaw.cpu()
+        self.sample_showInfo(
+            index,
+            [input, outputRaw],
+            ["validation inference input", "validation inference output"],
+            [
+                self.trainingParams.whileTrainingValidation / "input.nii.gz",
+                self.trainingParams.whileTrainingValidation / "output.nii.gz",
+            ],
+            not self.whileTrainingValidationNIfTIsaved,
+        )
+        self.whileTrainingValidationNIfTIsaved = True
if f_callback is not None:
metric = f_callback(sample, inferSpace, index, outputRaw)
return metric
@@ -353,7 +431,7 @@ def plot_bestModel(
sample,
result,
str(index),
-            self.trainingParams.outputDir / "validation" / f"bestModel-val-{index}.png",
+            self.trainingParams.postTrainingValidation / f"bestModel-val-{index}.png",
)
return 0.0


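A usage sketch of the reworked entry point (hypothetical driver code: the `trainer` object and the model path are assumptions for illustration, not part of this commit):

    # Hypothetical driver -- assumes `trainer` is an instance of this class,
    # already constructed from parsed CLI options.
    from pathlib import Path

    trainer.trainingTransformsAndSpace_setup()     # build training/validation LoaderCaches
    trainer.train(useModelFile=Path("model.pth"))  # resume from a saved state_dict
    # trainer.train()                              # or train from scratch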