Skip to content

Commit

Permalink
Switch back to torchmetrics
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed May 10, 2024
1 parent a1accd5 commit 98c8b8c
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"torch>=2.0.1",
"torchvision>=0.15.2",
"polars>=0.19.10",
"torchmetrics>=1.2.0",
"stringcase>=1.2.0",
"atpublic>=4.0",
"pytest-pycharm>=0.7.0",
Expand Down
4 changes: 4 additions & 0 deletions src/qusi/internal/light_curve_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ def light_curve_iter(self) -> Iterator[LightCurve]:
:return: The iterable of the light curves.
"""
light_curve_paths = self.path_getter.get_shuffled_paths()
if len(light_curve_paths) == 0:
raise ValueError('LightCurveCollection returned no paths.')
for light_curve_path in light_curve_paths:
times, fluxes = self.load_times_and_fluxes_from_path_function(
light_curve_path
Expand Down Expand Up @@ -264,6 +266,8 @@ def observation_iter(self) -> Iterator[LightCurveObservation]:
:return: The iterable of the light curves.
"""
light_curve_paths = self.path_getter.get_shuffled_paths()
if len(light_curve_paths) == 0:
raise ValueError('LightCurveObservationCollection returned no paths.')
for light_curve_path in light_curve_paths:
times, fluxes = self.light_curve_collection.load_times_and_fluxes_from_path(
light_curve_path
Expand Down
8 changes: 5 additions & 3 deletions src/qusi/internal/train_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from torch.nn import BCELoss, Module
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torcheval.metrics import BinaryAccuracy, BinaryAUROC

import wandb
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC

from qusi.internal.light_curve_dataset import InterleavedDataset, LightCurveDataset
from qusi.internal.logging import set_up_default_logger
from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration
Expand Down Expand Up @@ -51,14 +52,15 @@ def train_session(
if metric_functions is None:
metric_functions = [BinaryAccuracy(), BinaryAUROC()]
set_up_default_logger()
sessions_directory = Path("sessions")
sessions_directory.mkdir(exist_ok=True)
wandb_init(
process_rank=0,
project=logging_configuration.wandb_project,
entity=logging_configuration.wandb_entity,
settings=wandb.Settings(start_method="thread"),
dir=sessions_directory,
)
sessions_directory = Path("sessions")
sessions_directory.mkdir(exist_ok=True)
train_dataset = InterleavedDataset.new(*train_datasets)
torch.multiprocessing.set_start_method("spawn")
debug = False
Expand Down

0 comments on commit 98c8b8c

Please sign in to comment.