diff --git a/src/qusi/internal/lightning_train_session.py b/src/qusi/internal/lightning_train_session.py index e2f4594..9d0da4e 100644 --- a/src/qusi/internal/lightning_train_session.py +++ b/src/qusi/internal/lightning_train_session.py @@ -2,6 +2,7 @@ import datetime import logging +import math from pathlib import Path from warnings import warn @@ -75,8 +76,29 @@ def train_session( logging_metrics = [BinaryAccuracy(), BinaryAUROC()] set_up_default_logger() + + sessions_directory_path = Path('sessions') + session_name = f'{datetime.datetime.now():%Y_%m_%d_%H_%M_%S}' + sessions_directory_path.mkdir(exist_ok=True, parents=True) + loggers = [ + CSVLogger(save_dir=sessions_directory_path, name=session_name), + WandbLogger(save_dir=sessions_directory_path, name=session_name)] + trainer = lightning.Trainer( + max_epochs=hyperparameter_configuration.cycles, + limit_train_batches=hyperparameter_configuration.train_steps_per_cycle, + limit_val_batches=hyperparameter_configuration.validation_steps_per_cycle, + log_every_n_steps=0, + accelerator=system_configuration.accelerator, + logger=loggers, + ) + train_dataset = InterleavedDataset.new(*train_datasets) workers_per_dataloader = system_configuration.preprocessing_processes_per_train_process + + local_batch_size = round(hyperparameter_configuration.batch_size / trainer.world_size) + if local_batch_size == 0: + local_batch_size = 1 + if workers_per_dataloader == 0: prefetch_factor = None persistent_workers = False @@ -85,7 +107,7 @@ def train_session( persistent_workers = True train_dataloader = DataLoader( train_dataset, - batch_size=hyperparameter_configuration.batch_size, + batch_size=local_batch_size, pin_memory=True, persistent_workers=persistent_workers, prefetch_factor=prefetch_factor, @@ -95,7 +117,7 @@ def train_session( for validation_dataset in validation_datasets: validation_dataloader = DataLoader( validation_dataset, - batch_size=hyperparameter_configuration.batch_size, + batch_size=local_batch_size, 
pin_memory=True, persistent_workers=persistent_workers, prefetch_factor=prefetch_factor, @@ -105,18 +127,4 @@ def train_session( lightning_model = QusiLightningModule.new(model=model, optimizer=optimizer, loss_metric=loss_metric, logging_metrics=logging_metrics) - sessions_directory_path = Path(f'sessions') - session_name = f'{datetime.datetime.now():%Y_%m_%d_%H_%M_%S}' - sessions_directory_path.mkdir(exist_ok=True, parents=True) - loggers = [ - CSVLogger(save_dir=sessions_directory_path, name=session_name), - WandbLogger(save_dir=sessions_directory_path, name=session_name)] - trainer = lightning.Trainer( - max_epochs=hyperparameter_configuration.cycles, - limit_train_batches=hyperparameter_configuration.train_steps_per_cycle, - limit_val_batches=hyperparameter_configuration.validation_steps_per_cycle, - log_every_n_steps=0, - accelerator=system_configuration.accelerator, - logger=loggers, - ) trainer.fit(model=lightning_model, train_dataloaders=train_dataloader, val_dataloaders=validation_dataloaders)