diff --git a/src/qusi/train_session.py b/src/qusi/train_session.py index f627f54b..fa584781 100644 --- a/src/qusi/train_session.py +++ b/src/qusi/train_session.py @@ -26,11 +26,13 @@ def train_session( validation_datasets: list[LightCurveDataset], model: Module, hyperparameter_configuration: TrainHyperparameterConfiguration | None, + logging_configuration: TrainLoggingConfiguration | None, ): if hyperparameter_configuration is None: hyperparameter_configuration = TrainHyperparameterConfiguration.new() + if logging_configuration is None: + logging_configuration = TrainLoggingConfiguration.new() set_up_default_logger() - logging_configuration = TrainLoggingConfiguration.new() wandb_init( process_rank=0, project=logging_configuration.wandb_project, @@ -78,7 +80,8 @@ def train_session( metric_functions = [BinaryAccuracy()] optimizer = AdamW(model.parameters()) metric_functions: list[Module] = [ - metric_function.to(device, non_blocking=True) for metric_function in metric_functions + metric_function.to(device, non_blocking=True) + for metric_function in metric_functions ] for _cycle_index in range(hyperparameter_configuration.cycles): train_phase( @@ -103,7 +106,15 @@ def train_session( wandb_commit(process_rank=0) -def train_phase(dataloader, model, loss_function, metric_functions: list[Module], optimizer, steps, device): +def train_phase( + dataloader, + model, + loss_function, + metric_functions: list[Module], + optimizer, + steps, + device, +): model.train() total_loss = 0 metric_totals = np.zeros(shape=[len(metric_functions)]) @@ -112,7 +123,9 @@ def train_phase(dataloader, model, loss_function, metric_functions: list[Module] # TODO: The conversion to float32 probably shouldn't be here, but the default collate_fn seems to be converting # to float64. Probably should override the default collate. targets_on_device = targets.to(torch.float32).to(device, non_blocking=True) - input_features_on_device = input_features.to(torch.float32).to(device, non_blocking=True) + input_features_on_device = input_features.to(torch.float32).to( + device, non_blocking=True + ) predicted_targets = model(input_features_on_device) loss = loss_function(predicted_targets, targets_on_device) @@ -121,7 +134,10 @@ def train_phase(dataloader, model, loss_function, metric_functions: list[Module] loss.backward() optimizer.step() - loss, current = loss.to(device, non_blocking=True).item(), (batch_index + 1) * len(input_features_on_device) + loss, current = ( + loss.to(device, non_blocking=True).item(), + (batch_index + 1) * len(input_features_on_device), + ) total_loss += loss for metric_function_index, metric_function in enumerate(metric_functions): batch_metric_value = metric_function( @@ -129,13 +145,19 @@ def train_phase(dataloader, model, loss_function, metric_functions: list[Module] ).item() metric_totals[metric_function_index] += batch_metric_value if batch_index % 10 == 0: - logger.info(f"loss: {loss:>7f} [{current:>5d}/{steps * len(input_features_on_device):>5d}]") + logger.info( + f"loss: {loss:>7f} [{current:>5d}/{steps * len(input_features_on_device):>5d}]" + ) if batch_index + 1 >= steps: break wandb_log("loss", total_loss / steps, process_rank=0) cycle_metric_values = metric_totals / steps for metric_function_index, metric_function in enumerate(metric_functions): - wandb_log(f"{get_metric_name(metric_function)}", cycle_metric_values[metric_function_index], process_rank=0) + wandb_log( + f"{get_metric_name(metric_function)}", + cycle_metric_values[metric_function_index], + process_rank=0, + ) def get_metric_name(metric_function): @@ -145,7 +167,9 @@ def get_metric_name(metric_function): return metric_name -def validation_phase(dataloader, model, loss_function, metric_functions: list[Module], steps, device): +def validation_phase( + dataloader, model, loss_function, metric_functions: list[Module], steps, device +): model.eval() validation_loss = 0 metric_totals = np.zeros(shape=[len(metric_functions)]) @@ -153,9 +177,15 @@ def validation_phase(dataloader, model, loss_function, metric_functions: list[Mo with torch.no_grad(): for batch, (input_features, targets) in enumerate(dataloader): targets_on_device = targets.to(torch.float32).to(device, non_blocking=True) - input_features_on_device = input_features.to(torch.float32).to(device, non_blocking=True) + input_features_on_device = input_features.to(torch.float32).to( + device, non_blocking=True + ) predicted_targets = model(input_features_on_device) - validation_loss += loss_function(predicted_targets, targets_on_device).to(device, non_blocking=True).item() + validation_loss += ( + loss_function(predicted_targets, targets_on_device) + .to(device, non_blocking=True) + .item() + ) for metric_function_index, metric_function in enumerate(metric_functions): batch_metric_value = metric_function( predicted_targets.to(device, non_blocking=True), targets_on_device @@ -169,7 +199,11 @@ def validation_phase(dataloader, model, loss_function, metric_functions: list[Mo wandb_log("val_loss", validation_loss, process_rank=0) cycle_metric_values = metric_totals / steps for metric_function_index, metric_function in enumerate(metric_functions): - wandb_log(f"val_{get_metric_name(metric_function)}", cycle_metric_values[metric_function_index], process_rank=0) + wandb_log( + f"val_{get_metric_name(metric_function)}", + cycle_metric_values[metric_function_index], + process_rank=0, + ) def save_model(model: Module, suffix: str, process_rank: int):