Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolphpienaar committed Apr 17, 2024
1 parent 9ebe7a6 commit 857355f
Showing 1 changed file with 60 additions and 31 deletions.
91 changes: 60 additions & 31 deletions spleenseg/core/neuralnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, options: Namespace):
self.modelPth = Path(options.outputdir) / "model.pth"
self.modelONNX = Path(options.outputdir) / "model.onnx"
self.determinismSeed = self.options.determinismSeed
set_determinism(self.determinismSeed)
set_determinism(seed=self.determinismSeed)


@dataclass
Expand Down Expand Up @@ -103,10 +103,34 @@ def __init__(self, options: Namespace):
self.options = options
if options is not None:
self.device = torch.device(self.options.device)
torch.manual_seed(42)
self.model = self.model.to(self.device)
self.optimizer = torch.optim.Adam(self.model.parameters(), 1e-4)


def tensor_desc(
T: torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor], **kwargs
) -> torch.Tensor:
strAs: str = "meanstd"
v1: float = 0.0
v2: float = 0.0
tensor: torch.Tensor = torch.Tensor([v1, v2])
for k, v in kwargs.items():
if k.lower() == "desc":
strAs = v
T = torch.as_tensor(T)
match strAs:
case "meanstd":
tensor = torch.Tensor([T.mean().item(), T.std().item()])
case "l1l2":
tensor = torch.Tensor([T.abs().sum().item(), T.pow(2).sum().sqrt().item()])
case "minmax":
tensor = torch.Tensor([T.min().item(), T.max().item()])
case "simplified":
tensor = T.mean(dim=(1, 2), keepdim=True)
return tensor


class NeuralNet:
def __init__(self, options: Namespace):
self.network: ModelParams = ModelParams(options)
Expand All @@ -118,8 +142,8 @@ def __init__(self, options: Namespace):

self.trainingLog: TrainingLog = TrainingLog()

self.f_outputPost: Compose | None = None
self.f_labelPost: Compose | None = None
self.f_outputPost: Compose
self.f_labelPost: Compose

self.trainingSpace: LoaderCache
self.validationSpace: LoaderCache
Expand Down Expand Up @@ -150,9 +174,12 @@ def loaderCache_create(

# use batch_size=2 to load images and use RandCropByPosNegLabeld
# to generate 2 x 4 images for network training
loader: DataLoader = DataLoader(
ds, batch_size=batch_size, shuffle=True, num_workers=4
)
loader: DataLoader
if batch_size == 2:
loader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)
else:
loader = DataLoader(ds, batch_size=batch_size, num_workers=4)

loaderCache: LoaderCache = LoaderCache(cache=ds, loader=loader)
return loaderCache

Expand All @@ -174,25 +201,23 @@ def tensor_assign(
return T

def feedForward(
self, input: Optional[torch.Tensor] = None
self,
) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:
"""
Simply run the self.input and generate an output
Simply run the self.input and generate/return/store an output
"""
if input:
self.input = input
# print(tensor_desc(self.input))
self.output = self.network.model(self.input)
# print(tensor_desc(self.output))
return self.output

def evalAndCorrect(
self, input: torch.Tensor, target: Optional[torch.Tensor] = None
) -> float:
def evalAndCorrect(self) -> float:
self.network.optimizer.zero_grad()
if target is not None:
self.target = target

self.feedForward()
f_loss: torch.Tensor = self.network.fn_loss(
torch.as_tensor(self.feedForward(input)), torch.as_tensor(self.target)
torch.as_tensor(self.output),
torch.as_tensor(self.target),
)
f_loss.backward()
self.network.optimizer.step()
Expand All @@ -208,14 +233,14 @@ def train_overSampleSpace_retLoss(self, trainingSpace: LoaderCache) -> float:
trainingInstance["image"].to(self.network.device),
trainingInstance["label"].to(self.network.device),
)
sample_loss = self.evalAndCorrect(self.input, self.target)
sample_loss = self.evalAndCorrect()
total_loss += sample_loss
if (
trainingSpace.cache is not None
and trainingSpace.loader.batch_size is not None
):
print(
f"{sample}/{len(trainingSpace.cache) // trainingSpace.loader.batch_size}, "
f"{sample:02}/{len(trainingSpace.cache) // trainingSpace.loader.batch_size}, "
f"sample loss: {sample_loss:.4f}"
)
total_loss /= sample
Expand All @@ -232,19 +257,17 @@ def train(
self.trainingSpace = trainingSpace
if validationSpace:
self.validationSpace = validationSpace
self.network.model.train()
for epoch in range(self.training.max_epochs):
print("-" * 10)
print(f"epoch {epoch+1:03} / {self.training.max_epochs}")
self.network.model.train()
epoch_loss = self.train_overSampleSpace_retLoss(self.trainingSpace)
print(f"epoch {epoch+1:03}, average loss: {epoch_loss:.4f}")
self.trainingLog.loss_per_epoch.append(epoch_loss)
if (epoch + 1) % self.training.val_interval == 0:
self.slidingWindowInference_do(self.validationSpace, epoch)

self.slidingWindowInference_do(self.validationSpace)
print(f"current epoch: {epoch + 1}, current mean dice")

def inference_metricsProcess(self):
def inference_metricsProcess(self, epoch: int) -> float:
metric: float = self.network.dice_metric.aggregate().item()
self.trainingLog.metric_per_epoch.append(metric)
self.network.dice_metric.reset()
Expand All @@ -253,10 +276,17 @@ def inference_metricsProcess(self):
self.training.best_metric_epoch = epoch + 1
torch.save(self.network.model.state_dict(), str(self.training.modelPth))
print("saved new best metric model")
print(
f"current epoch: {epoch + 1}, current mean dice: {metric:.4f}"
f"\nbest mean dice: {self.training.best_metric:.4f}"
f"at epoch: {self.training.best_metric_epoch}"
)
return metric

def slidingWindowInference_do(
self, inferCache: LoaderCache, truthCache: LoaderCache | None = None
):
self, inferCache: LoaderCache, epoch: int | None = None
) -> float:
metric: float = 0.0
self.network.model.eval()
with torch.no_grad():
for sample in inferCache.loader:
Expand All @@ -271,13 +301,12 @@ def slidingWindowInference_do(
outputPostProc = [
self.f_outputPost(i) for i in decollate_batch(outputRaw)
]
if truthCache:
if epoch is not None:
labelTruth: torch.Tensor = sample["label"].to(self.network.device)
labelPostProc = [
self.f_labelPost(i) for i in decollate_batch(labelTruth)
]
self.network.dice_metric(
y_pred=training.val_outputs, y=training.val_labels
)
if truthCache:
self.inference_metricsProcess()
self.network.dice_metric(y_pred=outputPostProc, y=labelPostProc)
if epoch is not None:
metric = self.inference_metricsProcess(epoch)
return metric

0 comments on commit 857355f

Please sign in to comment.