diff --git a/src/qusi/internal/train_session.py b/src/qusi/internal/train_session.py
index b3c70d6..d96d884 100644
--- a/src/qusi/internal/train_session.py
+++ b/src/qusi/internal/train_session.py
@@ -15,6 +15,7 @@
 from qusi.internal.logging import set_up_default_logger, get_metric_name
 from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration
 from qusi.internal.train_logging_configuration import TrainLoggingConfiguration
+from qusi.internal.train_system_configuration import TrainSystemConfiguration
 from qusi.internal.wandb_liaison import wandb_commit, wandb_init, wandb_log
 
 logger = logging.getLogger(__name__)
@@ -29,6 +30,7 @@
     metric_functions: list[Module] | None = None,
     *,
     hyperparameter_configuration: TrainHyperparameterConfiguration | None = None,
+    system_configuration: TrainSystemConfiguration | None = None,
     logging_configuration: TrainLoggingConfiguration | None = None,
 ) -> None:
     """
@@ -40,7 +42,8 @@
     :param optimizer: The optimizer to be used during training.
     :param loss_function: The loss function to train the model on.
     :param metric_functions: A list of metric functions to record during the training process.
-    :param hyperparameter_configuration: The configuration of the hyperparameters
+    :param hyperparameter_configuration: The configuration of the hyperparameters.
+    :param system_configuration: The configuration of the system.
     :param logging_configuration: The configuration of the logging.
     """
     if hyperparameter_configuration is None:
@@ -69,7 +72,10 @@
         prefetch_factor = None
         persistent_workers = False
     else:
-        workers_per_dataloader = 10
+        if system_configuration is not None:
+            workers_per_dataloader = system_configuration.preprocessing_processes_per_train_process
+        else:
+            workers_per_dataloader = 10  # Previous hard-coded default; keeps callers that omit system_configuration working.
         prefetch_factor = 10
         persistent_workers = True
     train_dataloader = DataLoader(