From 1a1888bc0b054eb6940eacd576fcd26253ecf785 Mon Sep 17 00:00:00 2001
From: golmschenk
Date: Thu, 28 Nov 2024 14:48:00 -0500
Subject: [PATCH] Shuffle based on both the global rank and worker id

---
 src/qusi/internal/light_curve_dataset.py     | 19 +++++++++++++++++--
 src/qusi/internal/lightning_train_session.py |  5 ++++-
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/src/qusi/internal/light_curve_dataset.py b/src/qusi/internal/light_curve_dataset.py
index ba03607..27c340d 100644
--- a/src/qusi/internal/light_curve_dataset.py
+++ b/src/qusi/internal/light_curve_dataset.py
@@ -40,7 +40,7 @@
 if TYPE_CHECKING:
     from collections.abc import Iterable, Iterator
 
-    from qusi.internal.light_curve_collection import LightCurveObservationCollection
+from qusi.internal.light_curve_collection import LightCurveObservationCollection
 
 
 class OutOfBoundsInjectionHandlingMethod(Enum):
@@ -78,12 +78,27 @@ def __init__(
             raise ValueError(error_message)
         self.post_injection_transform: Callable[[Any], Any] = post_injection_transform
         self.worker_randomizing_set: bool = False
+        self.global_rank: int | None = None
+        self.world_size: int | None = None
+        self.worker_id: int | None = None
+        self.number_of_workers: int | None = None
+        self.global_worker_rank: int | None = None
 
     def __iter__(self):
+        # TODO: This is a hack and there are certainly better places to put this.
         if not self.worker_randomizing_set:
             worker_info = torch.utils.data.get_worker_info()
+            if self.global_rank is None:
+                self.global_rank = 0
+                self.world_size = 1
             if worker_info is not None:
-                self.seed_random(worker_info.id)
+                self.worker_id = worker_info.id
+                self.number_of_workers = worker_info.num_workers
+            else:
+                self.worker_id = 0
+                self.number_of_workers = 1
+            self.global_worker_rank = (self.global_rank * self.number_of_workers) + self.worker_id
+            self.seed_random(self.global_worker_rank)
             self.worker_randomizing_set = True
         base_light_curve_collection_iter_and_type_pairs: list[
             tuple[Iterator[Path], Callable[[Path], LightCurveObservation], LightCurveCollectionType]
diff --git a/src/qusi/internal/lightning_train_session.py b/src/qusi/internal/lightning_train_session.py
index 9d0da4e..336718e 100644
--- a/src/qusi/internal/lightning_train_session.py
+++ b/src/qusi/internal/lightning_train_session.py
@@ -2,7 +2,6 @@
 
 import datetime
 import logging
-import math
 from pathlib import Path
 from warnings import warn
 
@@ -91,6 +90,10 @@ def train_session(
         accelerator=system_configuration.accelerator,
         logger=loggers,
     )
+    # TODO: Not a fan of needing to magically pass the process number to the datasets here.
+    for train_dataset in train_datasets:
+        train_dataset.global_rank = trainer.global_rank
+        train_dataset.world_size = trainer.world_size
 
     train_dataset = InterleavedDataset.new(*train_datasets)
     workers_per_dataloader = system_configuration.preprocessing_processes_per_train_process
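
Note on the seeding scheme: the expression in __iter__ is the standard row-major flattening of a (global_rank, worker_id) grid into a single index, so it yields a distinct value for every dataloader worker in the cluster as long as each trainer process spawns the same number of workers. The following standalone sketch illustrates this with small, hypothetical numbers (world_size and workers_per_process are made-up example values, not qusi or Lightning APIs):

# Sketch of the seeding arithmetic used in this patch; assumes every
# trainer process spawns the same number of dataloader workers, and the
# counts below are illustrative values.
world_size = 2  # hypothetical number of trainer processes
workers_per_process = 3  # hypothetical dataloader workers per process

seeds = set()
for global_rank in range(world_size):
    for worker_id in range(workers_per_process):
        # Same row-major formula as in __iter__ above.
        global_worker_rank = (global_rank * workers_per_process) + worker_id
        print(f"rank={global_rank} worker={worker_id} -> seed {global_worker_rank}")
        seeds.add(global_worker_rank)

# Every rank/worker pair received a distinct seed (0 through 5 here).
assert len(seeds) == world_size * workers_per_process

Seeding on the worker id alone would give every rank the same seed set (0 through 2 here), so corresponding workers on different ranks would draw identical shuffles; multiplying the global rank in removes those collisions.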