Skip to content

Commit

Permalink
Add the ability to pass in a system configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed Jun 18, 2024
1 parent e5e431e commit 8f8b5cc
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/qusi/internal/train_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from qusi.internal.logging import set_up_default_logger, get_metric_name
from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration
from qusi.internal.train_logging_configuration import TrainLoggingConfiguration
from qusi.internal.train_system_configuration import TrainSystemConfiguration
from qusi.internal.wandb_liaison import wandb_commit, wandb_init, wandb_log

logger = logging.getLogger(__name__)
Expand All @@ -29,6 +30,7 @@ def train_session(
metric_functions: list[Module] | None = None,
*,
hyperparameter_configuration: TrainHyperparameterConfiguration | None = None,
system_configuration: TrainSystemConfiguration | None = None,
logging_configuration: TrainLoggingConfiguration | None = None,
) -> None:
"""
Expand All @@ -40,7 +42,8 @@ def train_session(
:param optimizer: The optimizer to be used during training.
:param loss_function: The loss function to train the model on.
:param metric_functions: A list of metric functions to record during the training process.
:param hyperparameter_configuration: The configuration of the hyperparameters
:param hyperparameter_configuration: The configuration of the hyperparameters.
:param system_configuration: The configuration of the system.
:param logging_configuration: The configuration of the logging.
"""
if hyperparameter_configuration is None:
Expand Down Expand Up @@ -69,7 +72,7 @@ def train_session(
prefetch_factor = None
persistent_workers = False
else:
workers_per_dataloader = 10
workers_per_dataloader = system_configuration.preprocessing_processes_per_train_process
prefetch_factor = 10
persistent_workers = True
train_dataloader = DataLoader(
Expand Down

0 comments on commit 8f8b5cc

Please sign in to comment.