Allow train logging configuration to be passed in
golmschenk committed Feb 29, 2024
1 parent 8b44bee commit 066b997
Showing 1 changed file with 45 additions and 11 deletions.
56 changes: 45 additions & 11 deletions src/qusi/train_session.py
@@ -26,11 +26,13 @@ def train_session(
validation_datasets: list[LightCurveDataset],
model: Module,
hyperparameter_configuration: TrainHyperparameterConfiguration | None,
logging_configuration: TrainLoggingConfiguration | None,
):
if hyperparameter_configuration is None:
hyperparameter_configuration = TrainHyperparameterConfiguration.new()
if logging_configuration is None:
logging_configuration = TrainLoggingConfiguration.new()
set_up_default_logger()
logging_configuration = TrainLoggingConfiguration.new()
wandb_init(
process_rank=0,
project=logging_configuration.wandb_project,
@@ -78,7 +80,8 @@ def train_session(
metric_functions = [BinaryAccuracy()]
optimizer = AdamW(model.parameters())
metric_functions: list[Module] = [
metric_function.to(device, non_blocking=True) for metric_function in metric_functions
metric_function.to(device, non_blocking=True)
for metric_function in metric_functions
]
for _cycle_index in range(hyperparameter_configuration.cycles):
train_phase(
@@ -103,7 +106,15 @@ def train_session(
wandb_commit(process_rank=0)


def train_phase(dataloader, model, loss_function, metric_functions: list[Module], optimizer, steps, device):
def train_phase(
dataloader,
model,
loss_function,
metric_functions: list[Module],
optimizer,
steps,
device,
):
model.train()
total_loss = 0
metric_totals = np.zeros(shape=[len(metric_functions)])
@@ -112,7 +123,9 @@ def train_phase(dataloader, model, loss_function, metric_functions: list[Module]
# TODO: The conversion to float32 probably shouldn't be here, but the default collate_fn seems to be converting
# to float64. Probably should override the default collate.
targets_on_device = targets.to(torch.float32).to(device, non_blocking=True)
input_features_on_device = input_features.to(torch.float32).to(device, non_blocking=True)
input_features_on_device = input_features.to(torch.float32).to(
device, non_blocking=True
)
predicted_targets = model(input_features_on_device)
loss = loss_function(predicted_targets, targets_on_device)

@@ -121,21 +134,30 @@ def train_phase(dataloader, model, loss_function, metric_functions: list[Module]
loss.backward()
optimizer.step()

loss, current = loss.to(device, non_blocking=True).item(), (batch_index + 1) * len(input_features_on_device)
loss, current = (
loss.to(device, non_blocking=True).item(),
(batch_index + 1) * len(input_features_on_device),
)
total_loss += loss
for metric_function_index, metric_function in enumerate(metric_functions):
batch_metric_value = metric_function(
predicted_targets.to(device, non_blocking=True), targets_on_device
).item()
metric_totals[metric_function_index] += batch_metric_value
if batch_index % 10 == 0:
logger.info(f"loss: {loss:>7f} [{current:>5d}/{steps * len(input_features_on_device):>5d}]")
logger.info(
f"loss: {loss:>7f} [{current:>5d}/{steps * len(input_features_on_device):>5d}]"
)
if batch_index + 1 >= steps:
break
wandb_log("loss", total_loss / steps, process_rank=0)
cycle_metric_values = metric_totals / steps
for metric_function_index, metric_function in enumerate(metric_functions):
wandb_log(f"{get_metric_name(metric_function)}", cycle_metric_values[metric_function_index], process_rank=0)
wandb_log(
f"{get_metric_name(metric_function)}",
cycle_metric_values[metric_function_index],
process_rank=0,
)


def get_metric_name(metric_function):
@@ -145,17 +167,25 @@ def get_metric_name(metric_function):
return metric_name


def validation_phase(dataloader, model, loss_function, metric_functions: list[Module], steps, device):
def validation_phase(
dataloader, model, loss_function, metric_functions: list[Module], steps, device
):
model.eval()
validation_loss = 0
metric_totals = np.zeros(shape=[len(metric_functions)])

with torch.no_grad():
for batch, (input_features, targets) in enumerate(dataloader):
targets_on_device = targets.to(torch.float32).to(device, non_blocking=True)
input_features_on_device = input_features.to(torch.float32).to(device, non_blocking=True)
input_features_on_device = input_features.to(torch.float32).to(
device, non_blocking=True
)
predicted_targets = model(input_features_on_device)
validation_loss += loss_function(predicted_targets, targets_on_device).to(device, non_blocking=True).item()
validation_loss += (
loss_function(predicted_targets, targets_on_device)
.to(device, non_blocking=True)
.item()
)
for metric_function_index, metric_function in enumerate(metric_functions):
batch_metric_value = metric_function(
predicted_targets.to(device, non_blocking=True), targets_on_device
@@ -169,7 +199,11 @@ def validation_phase(dataloader, model, loss_function, metric_functions: list[Mo
wandb_log("val_loss", validation_loss, process_rank=0)
cycle_metric_values = metric_totals / steps
for metric_function_index, metric_function in enumerate(metric_functions):
wandb_log(f"val_{get_metric_name(metric_function)}", cycle_metric_values[metric_function_index], process_rank=0)
wandb_log(
f"val_{get_metric_name(metric_function)}",
cycle_metric_values[metric_function_index],
process_rank=0,
)


def save_model(model: Module, suffix: str, process_rank: int):
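
For reference, a minimal usage sketch of the new parameter follows. It is illustrative only: the import locations, the name of the first (training) dataset argument, which is not visible in the hunks above, and the construction of the datasets and model are assumptions rather than part of this commit.

# Hypothetical sketch of calling train_session with an explicit logging configuration.
# The configuration classes are assumed to be importable alongside train_session;
# adjust the import to the actual qusi package layout.
from qusi.train_session import (
    TrainHyperparameterConfiguration,
    TrainLoggingConfiguration,
    train_session,
)

train_datasets = [...]  # placeholder: LightCurveDataset instances used for training
validation_datasets = [...]  # placeholder: LightCurveDataset instances used for validation
model = ...  # placeholder: any torch.nn.Module accepted by train_session

# Build a logging configuration explicitly; its wandb_project field is what wandb_init uses.
logging_configuration = TrainLoggingConfiguration.new()

train_session(
    train_datasets,  # first positional argument assumed to be the training dataset list
    validation_datasets=validation_datasets,
    model=model,
    hyperparameter_configuration=TrainHyperparameterConfiguration.new(),
    logging_configuration=logging_configuration,  # None falls back to TrainLoggingConfiguration.new(), per this commit
)

As the first hunk shows, the new parameter accepts None but is not given a default value in the signature, so callers pass it explicitly (possibly as None).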
