
Shuffle based on both the global rank and worker id
golmschenk committed Nov 28, 2024
1 parent 9be9887 · commit 1a1888b
Showing 2 changed files with 21 additions and 3 deletions.
src/qusi/internal/light_curve_dataset.py (19 changes: 17 additions & 2 deletions)
```diff
@@ -40,7 +40,7 @@
 if TYPE_CHECKING:
     from collections.abc import Iterable, Iterator
 
-    from qusi.internal.light_curve_collection import LightCurveObservationCollection
+    from qusi.internal.light_curve_collection import LightCurveObservationCollection
@@ -78,12 +78,27 @@ def __init__(
             raise ValueError(error_message)
         self.post_injection_transform: Callable[[Any], Any] = post_injection_transform
         self.worker_randomizing_set: bool = False
+        self.global_rank: int | None = None
+        self.world_size: int | None = None
+        self.worker_id: int | None = None
+        self.number_of_workers: int | None = None
+        self.global_worker_rank: int | None = None
+
     def __iter__(self):
         # TODO: This is a hack and there are certainly better places to put this.
         if not self.worker_randomizing_set:
             worker_info = torch.utils.data.get_worker_info()
+            if self.global_rank is None:
+                self.global_rank = 0
+                self.world_size = 1
             if worker_info is not None:
-                self.seed_random(worker_info.id)
+                self.worker_id = worker_info.id
+                self.number_of_workers = worker_info.num_workers
+            else:
+                self.worker_id = 0
+                self.number_of_workers = 0
+            self.global_worker_rank = (self.global_rank * self.number_of_workers) + self.worker_id
+            self.seed_random(self.worker_id)
             self.worker_randomizing_set = True
         base_light_curve_collection_iter_and_type_pairs: list[
             tuple[Iterator[Path], Callable[[Path], LightCurveObservation], LightCurveCollectionType]
```
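For context on the indexing above (not part of the commit): under distributed training, each trainer process has a global rank, and each process spawns its own dataloader workers whose ids restart at 0, so only the pair of the two identifies a data stream uniquely. That is what `(global_rank * number_of_workers) + worker_id` flattens into a single integer. The sketch below is a minimal, self-contained illustration of the scheme; the `ShuffledStream` name and data are hypothetical, and unlike the diff (which passes `self.worker_id` to `seed_random`), it seeds directly from the combined index.

```python
import random

from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class ShuffledStream(IterableDataset):
    """Hypothetical dataset: each (rank, worker) pair shuffles with its own seed."""

    def __init__(self, items: list[int], global_rank: int = 0):
        self.items = items
        self.global_rank = global_rank  # Set externally, e.g., by the training session.

    def __iter__(self):
        worker_info = get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        # Fall back to one worker (not zero) so the product below stays meaningful.
        number_of_workers = worker_info.num_workers if worker_info is not None else 1
        # The diff's flattening: with 2 ranks and 3 workers each, the global
        # worker ranks come out as 0, 1, 2 on rank 0 and 3, 4, 5 on rank 1.
        global_worker_rank = (self.global_rank * number_of_workers) + worker_id
        random_generator = random.Random(global_worker_rank)
        shuffled_items = list(self.items)
        random_generator.shuffle(shuffled_items)
        yield from shuffled_items


if __name__ == "__main__":
    # Two workers each yield their own copy of the items, shuffled differently
    # because their seeds differ.
    loader = DataLoader(ShuffledStream(list(range(8))), num_workers=2, batch_size=None)
    print(list(loader))
```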
src/qusi/internal/lightning_train_session.py (5 changes: 4 additions & 1 deletion)
```diff
@@ -2,7 +2,6 @@
 
 import datetime
 import logging
-import math
 from pathlib import Path
 from warnings import warn
@@ -91,6 +90,10 @@ def train_session(
         accelerator=system_configuration.accelerator,
         logger=loggers,
     )
+    # TODO: Not a fan of needing to magically pass the process number to the datasets here.
+    for train_dataset in train_datasets:
+        train_dataset.global_rank = trainer.global_rank
+        train_dataset.world_size = trainer.world_size
 
     train_dataset = InterleavedDataset.new(*train_datasets)
     workers_per_dataloader = system_configuration.preprocessing_processes_per_train_process
```
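The assignment in this hunk has to happen after the `Trainer` is constructed, so that `trainer.global_rank` reflects the actual process, but before any dataloader iterates the datasets, because the dataset's `__iter__` treats a still-unset `global_rank` as the single-process case. A minimal sketch of that ordering, assuming the Lightning 2.x import path and a hypothetical placeholder in place of the qusi datasets:

```python
from lightning.pytorch import Trainer


class PlaceholderDataset:
    """Hypothetical stand-in for the qusi light curve datasets."""

    def __init__(self):
        self.global_rank: int | None = None
        self.world_size: int | None = None


train_datasets = [PlaceholderDataset(), PlaceholderDataset()]
trainer = Trainer(accelerator="cpu", devices=1)
# Copy the trainer's process identity onto each dataset before any dataloader
# iterates it; a dataset whose global_rank is still None falls back to the
# single-process defaults (rank 0, world size 1).
for train_dataset in train_datasets:
    train_dataset.global_rank = trainer.global_rank
    train_dataset.world_size = trainer.world_size
```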
