Cb wandb (#453)
* add option to track training parameters with MLflow or WandB
camillebrianceau authored Sep 12, 2023
1 parent bbb1954 commit d0c7bc4
Showing 12 changed files with 2,416 additions and 1,305 deletions.
83 changes: 0 additions & 83 deletions clinicadl/mlflow_test.py

This file was deleted.

1 change: 1 addition & 0 deletions clinicadl/resources/config/train_config.toml
@@ -45,6 +45,7 @@ amp = false
seed = 0
deterministic = false
compensation = "memory" # Only used if deterministic = true
track_exp = ""

[Transfer_learning]
transfer_path = ""
1 change: 1 addition & 0 deletions clinicadl/train/tasks/classification_cli.py
@@ -46,6 +46,7 @@
@train_option.tolerance
@train_option.accumulation_steps
@train_option.profiler
@train_option.track_exp
# transfer learning
@train_option.transfer_path
@train_option.transfer_selection_metric
1 change: 1 addition & 0 deletions clinicadl/train/tasks/reconstruction_cli.py
@@ -46,6 +46,7 @@
@train_option.tolerance
@train_option.accumulation_steps
@train_option.profiler
@train_option.track_exp
# transfer learning
@train_option.transfer_path
@train_option.transfer_selection_metric
1 change: 1 addition & 0 deletions clinicadl/train/tasks/regression_cli.py
@@ -46,6 +46,7 @@
@train_option.tolerance
@train_option.accumulation_steps
@train_option.profiler
@train_option.track_exp
# transfer learning
@train_option.transfer_path
@train_option.transfer_selection_metric
1 change: 1 addition & 0 deletions clinicadl/train/tasks/task_utils.py
@@ -57,6 +57,7 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs):
"patience",
"profiler",
"tolerance",
"track_exp",
"transfer_path",
"transfer_selection_metric",
"weight_decay",
12 changes: 12 additions & 0 deletions clinicadl/utils/cli_param/train_option.py
@@ -282,6 +282,18 @@
help="Use `--profiler` to enable Pytorch profiler for the first 30 steps after a short warmup. "
"It will make an execution trace and some statistics about the CPU and GPU usage.",
)
track_exp = cli_param.option_group.optimization_group.option(
"--track_exp",
"-te",
type=click.Choice(
[
"wandb",
"mlflow",
"",
]
),
help="Use `--track_exp` to enable wandb/mlflow to track the metric (loss, accuracy, etc...) during the training.",
)
# transfer learning
transfer_path = cli_param.option_group.transfer_learning_group.option(
"-tp",
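
A minimal standalone sketch (not part of this commit) of how a `click.Choice`-typed option such as the new `--track_exp` above behaves: Click rejects any value outside the listed choices before the command body runs. The command name, file name, and help text here are placeholders.

import click


@click.command()
@click.option(
    "--track_exp",
    "-te",
    type=click.Choice(["wandb", "mlflow", ""]),
    default="",
    help="Experiment tracker to use (empty string disables tracking).",
)
def train(track_exp):
    # Echo whether tracking is enabled, mimicking how the value would reach the trainer.
    if track_exp:
        click.echo(f"Tracking enabled with {track_exp}")
    else:
        click.echo("Tracking disabled")


if __name__ == "__main__":
    train()  # e.g. `python demo.py --track_exp wandb` prints "Tracking enabled with wandb"
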
34 changes: 33 additions & 1 deletion clinicadl/utils/maps_manager/maps_manager.py
@@ -831,7 +831,7 @@ def _train(
logger.info(f"Criterion for {self.network_task} is {criterion}")

optimizer = self._init_optimizer(model, split=split, resume=resume)
logger.debug(f"Optimizer used for training is optimizer")
logger.debug(f"Optimizer used for training is {optimizer}")

model.train()
train_loader.dataset.train()
@@ -856,6 +856,16 @@
scaler = GradScaler(enabled=self.amp)
profiler = self._init_profiler()

if self.parameters["track_exp"] == "wandb":
from clinicadl.utils.tracking_exp import WandB_handler

run = WandB_handler(split, self.parameters, self.maps_path.name)

if self.parameters["track_exp"] == "mlflow":
from clinicadl.utils.tracking_exp import Mlflow_handler

run = Mlflow_handler(split, self.parameters, self.maps_path.name)

while epoch < self.epochs and not early_stopping.step(metrics_valid["loss"]):
logger.info(f"Beginning epoch {epoch}.")

@@ -956,6 +966,23 @@ def _train(
f"{self.mode} level validation loss is {metrics_valid['loss']} "
f"at the end of iteration {i}"
)
if self.track_exp == "wandb":
run.log_metrics(
run._wandb,
self.track_exp,
self.network_task,
metrics_train,
metrics_valid,
)

if self.track_exp == "mlflow":
run.log_metrics(
run._mlflow,
self.track_exp,
self.network_task,
metrics_train,
metrics_valid,
)

# Save checkpoints and best models
best_dict = retain_best.step(metrics_valid)
@@ -981,6 +1008,11 @@
)

epoch += 1
if self.parameters["track_exp"] == "mlflow":
run._mlflow.end_run()

if self.parameters["track_exp"] == "wandb":
run._wandb.finish()

self._test_loader(
train_loader,
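
For readability, the two `track_exp` branches shown in the hunks above (handler creation before the epoch loop, run finalization after it) could be factored into small helpers. A sketch only, not part of this commit; `init_tracker` and `close_tracker` are hypothetical names, while `WandB_handler` and `Mlflow_handler` are the classes added in `clinicadl/utils/tracking_exp.py` below.

def init_tracker(track_exp, split, parameters, maps_name):
    # Return the experiment-tracking handler matching `track_exp`, or None if tracking is disabled.
    if track_exp == "wandb":
        from clinicadl.utils.tracking_exp import WandB_handler

        return WandB_handler(split, parameters, maps_name)
    if track_exp == "mlflow":
        from clinicadl.utils.tracking_exp import Mlflow_handler

        return Mlflow_handler(split, parameters, maps_name)
    return None


def close_tracker(track_exp, run):
    # Finalize the tracked run once training on the split is over.
    if track_exp == "mlflow":
        run._mlflow.end_run()
    elif track_exp == "wandb":
        run._wandb.finish()
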
137 changes: 137 additions & 0 deletions clinicadl/utils/tracking_exp.py
@@ -0,0 +1,137 @@
"""Training Callbacks for training monitoring integrated in `pythae` (inspired from
https://github.com/huggingface/transformers/blob/master/src/transformers/trainer_callback.py)"""

import importlib
import logging
from copy import copy
from pathlib import Path

logger = logging.getLogger(__name__)


def wandb_is_available():
return importlib.util.find_spec("wandb")


def mlflow_is_available():
return importlib.util.find_spec("mlflow") is not None


class Tracker:
"""
Base class to track the metrics during training depending on the network task.
"""

def __init__(self):
pass

def log_metrics(
self,
tracker,
track_exp: bool = False,
network_task: str = "classification",
metrics_train: list = [],
metrics_valid: list = [],
):
metrics_dict = {}
if network_task == "classification":
metrics_dict = {
"loss_train": metrics_train["loss"],
"accuracy_train": metrics_train["accuracy"],
"sensitivity_train": metrics_train["sensitivity"],
"accuracy_train": metrics_train["accuracy"],
"specificity_train": metrics_train["specificity"],
"PPV_train": metrics_train["PPV"],
"NPV_train": metrics_train["NPV"],
"BA_train": metrics_train["BA"],
"loss_valid": metrics_valid["loss"],
"accuracy_valid": metrics_valid["accuracy"],
"sensitivity_valid": metrics_valid["sensitivity"],
"accuracy_valid": metrics_valid["accuracy"],
"specificity_valid": metrics_valid["specificity"],
"PPV_valid": metrics_valid["PPV"],
"NPV_valid": metrics_valid["NPV"],
"BA_valid": metrics_valid["BA"],
}
elif network_task == "reconstruction":
metrics_dict = {
"loss_train": metrics_train["loss"],
"MSE_train": metrics_train["MSE"],
"MAE_train": metrics_train["MAE"],
"PSNR_train": metrics_train["PSNR"],
"SSIM_train": metrics_train["SSIM"],
"loss_valid": metrics_valid["loss"],
"MSE_valid": metrics_valid["MSE"],
"MAE_valid": metrics_valid["MAE"],
"PSNR_valid": metrics_valid["PSNR"],
"SSIM_valid": metrics_valid["SSIM"],
}
elif network_task == "regression":
metrics_dict = {
"loss_train": metrics_train["loss"],
"MSE_train": metrics_train["MSE"],
"MAE_train": metrics_train["MAE"],
"loss_valid": metrics_valid["loss"],
"MSE_valid": metrics_valid["MSE"],
"MAE_valid": metrics_valid["MAE"],
}

if track_exp == "wandb":
tracker.log(metrics_dict)
return metrics_dict
elif track_exp == "mlflow":
tracker.log_metrics(metrics_dict)
return metrics_dict


class WandB_handler(Tracker):
def __init__(self, split: str, config: dict, maps_name: str):
if not wandb_is_available():
raise ModuleNotFoundError(
"`wandb` package must be installed. Run `pip install wandb`"
)
else:
import wandb

self._wandb = wandb

self._wandb.init(
project="ClinicaDL",
entity="clinicadl",
config=config,
save_code=True,
group=maps_name,
mode="online",
name=f"split-{split}",
reinit=True,
)


class Mlflow_handler(Tracker):
def __init__(self, split: str, config: dict, maps_name: str):
if not mlflow_is_available():
raise ModuleNotFoundError(
"`mlflow` package must be installed. Run `pip install mlflow`"
)
else:
import mlflow

self._mlflow = mlflow

try:
experiment_id = self._mlflow.create_experiment(
f"clinicadl-{maps_name}",
artifact_location=Path.cwd().joinpath("mlruns").as_uri(),
)

except mlflow.exceptions.MlflowException:
self._mlflow.set_experiment(maps_name)

self._mlflow.start_run(experiment_id=experiment_id, run_name=f"split-{split}")
self._mlflow.autolog()
config_bis = copy(config)
for cle, valeur in config.items():
if cle == "preprocessing_dict":
del config_bis[cle]
config = config_bis
self._mlflow.log_params(config)
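
A hedged usage sketch (not part of the commit) showing the metric-dictionary shape `Tracker.log_metrics` expects for the regression task. `_DummyTracker` is a hypothetical stand-in for the `wandb` module, and the metric values are made up for illustration.

from clinicadl.utils.tracking_exp import Tracker


class _DummyTracker:
    # Stand-in for the wandb module: exposes the same `log` method used by Tracker.log_metrics.
    def log(self, metrics):
        print(metrics)


logged = Tracker().log_metrics(
    _DummyTracker(),
    track_exp="wandb",
    network_task="regression",
    metrics_train={"loss": 0.42, "MSE": 0.40, "MAE": 0.31},
    metrics_valid={"loss": 0.51, "MSE": 0.49, "MAE": 0.35},
)
# logged == {"loss_train": 0.42, "MSE_train": 0.40, ..., "MAE_valid": 0.35}
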
3 changes: 3 additions & 0 deletions docs/Train/Introduction.md
@@ -106,6 +106,9 @@ Options shared for all values of `NETWORK_TASK` are organized in groups:
- `--transfer_selection_metric` (str) is the transfer learning selection metric.
- `--nb_unfrozen_layer` (int) is the number of layers that will be retrained during training. For example, if it is 2, the last two layers of the model will not be frozen.
See [Implementation details](Details.md/#transfer-learning) for more information about transfer learning.
- **Track an experiment**
- `--track_exp` (str) is the name of the experiment tracker to use. It must be either `wandb` (Weights & Biases) or `mlflow`, e.g. `--track_exp wandb`. As MLflow and W&B are not ClinicaDL dependencies, you must install the one you choose yourself (`pip install wandb` or `pip install mlflow`).
For more information, check out the documentation of [W&B](https://docs.wandb.ai) or [MLflow](https://mlflow.org/docs/latest/index.html).

<!---
!!! tip
