Skip to content

Commit

Permalink
Add metric passing
Browse files — browse the repository at this point in the history
  • Loading branch information
golmschenk committed Mar 26, 2024
1 parent d03aa10 commit 37cf34f
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/qusi/train_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,19 @@ def train_session(
train_datasets: list[LightCurveDataset],
validation_datasets: list[LightCurveDataset],
model: Module,
loss_function: Module | None = None,
metric_functions: list[Module] | None = None,
hyperparameter_configuration: TrainHyperparameterConfiguration | None = None,
logging_configuration: TrainLoggingConfiguration | None = None,
):
if hyperparameter_configuration is None:
hyperparameter_configuration = TrainHyperparameterConfiguration.new()
if logging_configuration is None:
logging_configuration = TrainLoggingConfiguration.new()
if loss_function is None:
loss_function = BCELoss()
if metric_functions is None:
metric_functions = [BinaryAccuracy()]
set_up_default_logger()
wandb_init(
process_rank=0,
Expand Down Expand Up @@ -76,8 +82,7 @@ def train_session(
else:
device = torch.device("cpu")
model = model.to(device, non_blocking=True)
loss_function = BCELoss().to(device, non_blocking=True)
metric_functions = [BinaryAccuracy()]
loss_function = loss_function.to(device, non_blocking=True)
optimizer = AdamW(model.parameters())
metric_functions: list[Module] = [
metric_function.to(device, non_blocking=True)
Expand Down

0 comments on commit 37cf34f

Please sign in to comment.