From 6eeadb33f909ba313f69bc261643117bf3119727 Mon Sep 17 00:00:00 2001 From: Rudolph Pienaar Date: Thu, 25 Apr 2024 10:14:44 -0400 Subject: [PATCH] Add more behaviour (continuous training) --- spleenseg/spleenseg.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/spleenseg/spleenseg.py b/spleenseg/spleenseg.py index 36fffcd..41696d9 100644 --- a/spleenseg/spleenseg.py +++ b/spleenseg/spleenseg.py @@ -17,6 +17,7 @@ from typing import Any, Optional, Callable from spleenseg.core import neuralnet +from spleenseg.models.data import TrainingParams from spleenseg.transforms import transforms from spleenseg.plotting import plotting import warnings @@ -56,6 +57,12 @@ default="training", help="mode of behaviour: training or inference", ) +parser.add_argument( + "--logTrainingTransformVols", + default=False, + action="store_true", + help="If specified, save training tranformations as NIfTI volumes", +) parser.add_argument( "--useModel", type=str, @@ -171,12 +178,21 @@ def envDetail_print(options: Namespace, **kwargs): def env_outputDirsMake(options: Namespace) -> None: if "training" in options.mode: - Path(Path(options.outputdir) / "training").mkdir(parents=True, exist_ok=True) - Path(Path(options.outputdir) / "validation").mkdir(parents=True, exist_ok=True) + params: TrainingParams = TrainingParams(options) + params.whileTrainingIO.mkdir(parents=True, exist_ok=True) + params.whileTrainingValidation.mkdir(parents=True, exist_ok=True) + params.postTrainingValidation.mkdir(parents=True, exist_ok=True) if "inference" in options.mode: Path(Path(options.outputdir) / "inference").mkdir(parents=True, exist_ok=True) +def modelFile_inputdirGet(options: Namespace) -> Path: + modelFile: Path = Path(Path(options.inputdir) / options.useModel) + if not modelFile.exists(): + raise FileNotFoundError(f"The model '{modelFile}' does not exist.") + return modelFile + + def training_do(neuralNet: neuralnet.NeuralNet, options: Namespace) -> bool: trainingOK: bool = True @@ -189,6 +205,8 @@ def training_do(neuralNet: neuralnet.NeuralNet, options: Namespace) -> bool: if options.mode == "training": neuralNet.train() + if options.mode == "trainingContinue": + neuralNet.train(modelFile_inputdirGet(options)) plotting.plot_trainingMetrics( neuralNet.trainingLog, @@ -206,11 +224,7 @@ def inference_do(neuralNet: neuralnet.NeuralNet, options: Namespace) -> bool: inferenceOK: bool = True neuralNet.testingFileSet = testingData_prep(options) - modelFile: Path = Path(Path(options.inputdir) / options.useModel) - if not modelFile.exists(): - raise FileNotFoundError(f"The model '{modelFile}' does not exist.") - - neuralNet.infer_usingModel(modelFile) + neuralNet.infer_usingModel(modelFile_inputdirGet(options)) return inferenceOK