Add multiple loggers
golmschenk committed Nov 1, 2024
1 parent 5f31595 commit fb7cdd7
Showing 4 changed files with 23 additions and 11 deletions.
1 change: 0 additions & 1 deletion src/qusi/internal/light_curve_dataset.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import copy
-import itertools
 import math
 import re
 import shutil
12 changes: 11 additions & 1 deletion src/qusi/internal/lightning_train_session.py
@@ -1,9 +1,12 @@
 from __future__ import annotations
 
+import datetime
+import logging
 from pathlib import Path
 from warnings import warn
 
 import lightning
+from lightning.pytorch.loggers import CSVLogger, WandbLogger
 from torch.nn import BCELoss, Module
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
@@ -102,11 +105,18 @@ def train_session(
 
     lightning_model = QusiLightningModule.new(model=model, optimizer=optimizer, loss_metric=loss_metric,
                                               logging_metrics=logging_metrics)
+    sessions_directory_path = Path(f'sessions')
+    session_name = f'{datetime.datetime.now():%Y_%m_%d_%H_%M_%S}'
+    sessions_directory_path.mkdir(exist_ok=True, parents=True)
+    loggers = [
+        CSVLogger(save_dir=sessions_directory_path, name=session_name),
+        WandbLogger(save_dir=sessions_directory_path, name=session_name)]
     trainer = lightning.Trainer(
         max_epochs=hyperparameter_configuration.cycles,
         limit_train_batches=hyperparameter_configuration.train_steps_per_cycle,
         limit_val_batches=hyperparameter_configuration.validation_steps_per_cycle,
         log_every_n_steps=0,
-        accelerator=system_configuration.accelerator
+        accelerator=system_configuration.accelerator,
+        logger=loggers,
     )
     trainer.fit(model=lightning_model, train_dataloaders=train_dataloader, val_dataloaders=validation_dataloaders)
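
Note: the change above relies on Lightning's Trainer accepting a list for its logger argument; every value logged from the LightningModule is then fanned out to each logger in the list. A minimal standalone sketch of the same pattern, mirroring the commit's code (the Trainer here is otherwise unconfigured):

import datetime
from pathlib import Path

import lightning
from lightning.pytorch.loggers import CSVLogger, WandbLogger

# A shared directory and a timestamped session name keep both loggers' output together.
sessions_directory_path = Path('sessions')
sessions_directory_path.mkdir(exist_ok=True, parents=True)
session_name = f'{datetime.datetime.now():%Y_%m_%d_%H_%M_%S}'

# Passing a list sends every logged metric to both backends.
loggers = [
    CSVLogger(save_dir=sessions_directory_path, name=session_name),
    WandbLogger(save_dir=sessions_directory_path, name=session_name),
]
trainer = lightning.Trainer(logger=loggers)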
17 changes: 9 additions & 8 deletions src/qusi/internal/module.py
@@ -53,7 +53,7 @@ class QusiLightningModule(LightningModule):
     def new(
             cls,
             model: Module,
-            optimizer: Optimizer | None,
+            optimizer: Optimizer | None = None,
             loss_metric: Module | None = None,
             logging_metrics: list[Module] | None = None,
     ) -> Self:
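
Note: giving optimizer a default of None makes it optional at the call site, matching loss_metric and logging_metrics; presumably the classmethod substitutes a default optimizer when None is received. A hedged standalone sketch of that pattern (the function and the AdamW default below are illustrative, not qusi's implementation):

from __future__ import annotations

from torch.nn import Linear, Module
from torch.optim import AdamW, Optimizer

def new_module(model: Module, optimizer: Optimizer | None = None) -> tuple[Module, Optimizer]:
    # With a None default, callers may omit the optimizer and a default is built internally.
    if optimizer is None:
        optimizer = AdamW(model.parameters())
    return model, optimizer

model, optimizer = new_module(Linear(4, 1))  # optimizer argument omitted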
@@ -101,7 +101,7 @@ def on_train_epoch_start(self) -> None:
         # Due to Lightning's inconsistent step ordering, performing this during the train epoch start gives the most
         # consistent results.
         self.cycle += 1
-        self.log(name='cycle', value=self.cycle)
+        self.log(name='cycle', value=self.cycle, reduce_fx=torch.max, rank_zero_only=True, on_step=False, on_epoch=True)
 
     def training_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT:
         return self.compute_loss_and_metrics(batch, self.train_metric_group)
@@ -122,13 +122,17 @@ def compute_loss_and_metrics(self, batch: tuple[Any, Any], metric_group: MetricGroup):
         return loss
 
     def on_train_epoch_end(self) -> None:
-        self.log_loss_and_metrics(self.train_metric_group)
+        self.log_loss_and_metrics(self.train_metric_group, logging_name_prefix='')
 
     def log_loss_and_metrics(self, metric_group: MetricGroup, logging_name_prefix: str = ''):
+        mean_cycle_loss = metric_group.loss_cycle_total / metric_group.steps_run_in_phase
+        self.log(name=logging_name_prefix + 'loss',
+                 value=mean_cycle_loss, sync_dist=True, on_step=False,
+                 on_epoch=True)
         for state_based_logging_metric in metric_group.state_based_logging_metrics:
             state_based_logging_metric_name = get_metric_name(state_based_logging_metric)
             self.log(name=logging_name_prefix + state_based_logging_metric_name,
-                     value=state_based_logging_metric.compute(), sync_dist=True)
+                     value=state_based_logging_metric.compute(), sync_dist=True, on_step=False, on_epoch=True)
             state_based_logging_metric.reset()
         for functional_logging_metric_index, functional_logging_metric in enumerate(
                 metric_group.functional_logging_metrics):
@@ -139,10 +143,7 @@ def log_loss_and_metrics(self, metric_group: MetricGroup, logging_name_prefix: str = ''):
             functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / metric_group.steps_run_in_phase
             self.log(name=logging_name_prefix + functional_logging_metric_name,
                      value=functional_logging_metric_cycle_mean,
-                     sync_dist=True)
-        mean_cycle_loss = metric_group.loss_cycle_total / metric_group.steps_run_in_phase
-        self.log(name=logging_name_prefix + 'loss',
-                 value=mean_cycle_loss, sync_dist=True)
+                     sync_dist=True, on_step=False, on_epoch=True)
         metric_group.loss_cycle_total.zero_()
         metric_group.steps_run_in_phase.zero_()
         metric_group.functional_logging_metric_cycle_totals.zero_()
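
Note: on the new self.log arguments, as a hedged summary of Lightning's documented behavior: on_step=False with on_epoch=True accumulates the value and emits a single entry at epoch end; reduce_fx sets how accumulated values are reduced (torch.max keeps the largest); rank_zero_only=True logs only from the rank-zero process; and sync_dist=True synchronizes the value across distributed workers before it reaches the loggers. A standalone sketch with invented values:

import torch
from lightning import LightningModule

class LoggingSketch(LightningModule):
    # Minimal stand-in showing the logging arguments used in the commit.

    def on_train_epoch_start(self) -> None:
        # Accumulate across the epoch, reduce with max, and log from rank zero only,
        # so the cycle count appears once per epoch even in multi-process runs.
        self.log(name='cycle', value=1.0, reduce_fx=torch.max,
                 rank_zero_only=True, on_step=False, on_epoch=True)

    def on_train_epoch_end(self) -> None:
        # sync_dist=True reduces the value across workers before logging.
        self.log(name='loss', value=0.5, sync_dist=True,
                 on_step=False, on_epoch=True)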
4 changes: 3 additions & 1 deletion src/qusi/internal/toy_light_curve_collection.py
@@ -112,7 +112,9 @@ def get_toy_dataset():
         standard_light_curve_collections=[
             get_toy_sine_wave_light_curve_observation_collection(),
             get_toy_flat_light_curve_observation_collection(),
-        ]
+        ],
+        post_injection_transform=partial(default_light_curve_observation_post_injection_transform,
+                                         length=100, randomize=False)
     )


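Note: partial here pre-binds the transform's keyword arguments, producing the single-argument callable the dataset applies to each observation. A small standalone sketch (the stub below is hypothetical and only mirrors the call signature):

from functools import partial

# Hypothetical stand-in for qusi's transform; only the call pattern matters here.
def post_injection_transform_stub(observation, *, length, randomize):
    return observation

# Binding length and randomize now yields a one-argument callable for later use.
transform = partial(post_injection_transform_stub, length=100, randomize=False)
transformed = transform('observation placeholder')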
