Skip to content

Commit 289ba8d

Browse files
committed
Fix bug where multiple dataloader workers were producing the same batches
1 parent 8da5e8a commit 289ba8d

File tree

4 files changed

+93
-20
lines changed

4 files changed

+93
-20
lines changed

src/qusi/internal/light_curve_collection.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __getitem__(self, indexes: int | tuple[int]) -> Path | tuple[Path]:
6161

6262

6363
class PathGetterBase(PathIterableBase, PathIndexableBase):
64-
pass
64+
random_number_generator: Random
6565

6666

6767
@dataclass
@@ -265,17 +265,26 @@ def observation_iter(self) -> Iterator[LightCurveObservation]:
265265
266266
:return: The iterable of the light curves.
267267
"""
268+
light_curve_paths = self.path_iter()
269+
for light_curve_path in light_curve_paths:
270+
light_curve_observation = self.observation_from_path(light_curve_path)
271+
yield light_curve_observation
272+
273+
def observation_from_path(self, light_curve_path: Path) -> LightCurveObservation:
    """
    Builds the light curve observation that corresponds to a single path.

    The times and fluxes are loaded through the light curve collection, and the label is loaded
    through the collection's label loading function.

    :param light_curve_path: The path to load the light curve and label from.
    :return: The light curve observation, with the source path attached to it.
    """
    load_times_and_fluxes = self.light_curve_collection.load_times_and_fluxes_from_path
    times, fluxes = load_times_and_fluxes(light_curve_path)
    label = self.load_label_from_path_function(light_curve_path)
    observation = LightCurveObservation.new(LightCurve.new(times, fluxes), label)
    observation.path = light_curve_path  # TODO: Quick debug hack.
    return observation
282+
283+
def path_iter(self) -> Iterable[Path]:
    """
    Gets the shuffled light curve paths from the path getter.

    :return: The shuffled paths.
    :raises ValueError: If the path getter produced no paths.
    """
    shuffled_paths = self.path_getter.get_shuffled_paths()
    if len(shuffled_paths) != 0:
        return shuffled_paths
    raise ValueError('LightCurveObservationCollection returned no paths.')
279288

280289
def __getitem__(self, index: int) -> LightCurveObservation:
281290
light_curve_path = self.path_getter[index]

src/qusi/internal/light_curve_dataset.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from __future__ import annotations
22

33
import copy
4+
import itertools
45
import math
56
import re
67
import shutil
78
import socket
89
from enum import Enum
910
from functools import partial
1011
from pathlib import Path
12+
from random import Random
1113
from typing import TYPE_CHECKING, Any, Callable, TypeVar
1214

1315
import numpy as np
@@ -75,52 +77,64 @@ def __init__(
7577
)
7678
raise ValueError(error_message)
7779
self.post_injection_transform: Callable[[Any], Any] = post_injection_transform
80+
self.worker_randomizing_set: bool = False
7881

7982
def __iter__(self):
83+
if not self.worker_randomizing_set:
84+
worker_info = torch.utils.data.get_worker_info()
85+
if worker_info is not None:
86+
self.seed_random(worker_info.id)
87+
self.worker_randomizing_set = True
8088
base_light_curve_collection_iter_and_type_pairs: list[
81-
tuple[Iterator[LightCurveObservation], LightCurveCollectionType]
89+
tuple[Iterator[Path], Callable[[Path], LightCurveObservation], LightCurveCollectionType]
8290
] = []
8391
injectee_collections = copy.copy(self.injectee_light_curve_collections)
8492
for standard_collection in self.standard_light_curve_collections:
8593
if standard_collection in injectee_collections:
8694
base_light_curve_collection_iter_and_type_pairs.append(
8795
(
88-
loop_iter_function(standard_collection.observation_iter),
96+
loop_iter_function(standard_collection.path_iter),
97+
standard_collection.observation_from_path,
8998
LightCurveCollectionType.STANDARD_AND_INJECTEE,
9099
)
91100
)
92101
injectee_collections.remove(standard_collection)
93102
else:
94103
base_light_curve_collection_iter_and_type_pairs.append(
95104
(
96-
loop_iter_function(standard_collection.observation_iter),
105+
loop_iter_function(standard_collection.path_iter),
106+
standard_collection.observation_from_path,
97107
LightCurveCollectionType.STANDARD,
98108
)
99109
)
100110
for injectee_collection in injectee_collections:
101111
base_light_curve_collection_iter_and_type_pair = (
102-
loop_iter_function(injectee_collection.observation_iter),
112+
loop_iter_function(injectee_collection.path_iter),
113+
injectee_collection.observation_from_path,
103114
LightCurveCollectionType.INJECTEE,
104115
)
105116
base_light_curve_collection_iter_and_type_pairs.append(base_light_curve_collection_iter_and_type_pair)
106117
injectable_light_curve_collection_iters: list[
107-
Iterator[LightCurveObservation]
118+
tuple[Iterator[Path], Callable[[Path], LightCurveObservation]]
108119
] = []
109120
for injectable_collection in self.injectable_light_curve_collections:
110-
injectable_light_curve_collection_iter = loop_iter_function(injectable_collection.observation_iter)
111-
injectable_light_curve_collection_iters.append(injectable_light_curve_collection_iter)
121+
injectable_light_curve_collection_iter = loop_iter_function(injectable_collection.path_iter)
122+
injectable_light_curve_collection_iters.append(
123+
(injectable_light_curve_collection_iter, injectable_collection.observation_from_path))
112124
while True:
113125
for (
114126
base_light_curve_collection_iter_and_type_pair
115127
) in base_light_curve_collection_iter_and_type_pairs:
116-
(base_collection_iter, collection_type) = base_light_curve_collection_iter_and_type_pair
128+
(base_collection_iter, observation_from_path_function,
129+
collection_type) = base_light_curve_collection_iter_and_type_pair
117130
if collection_type in [
118131
LightCurveCollectionType.STANDARD,
119132
LightCurveCollectionType.STANDARD_AND_INJECTEE,
120133
]:
121134
# TODO: Preprocessing step should be here. Or maybe that should all be on the light curve collection
122135
# as well? Or passed in somewhere else?
123-
standard_light_curve = next(base_collection_iter)
136+
standard_path = next(base_collection_iter)
137+
standard_light_curve = observation_from_path_function(standard_path)
124138
transformed_standard_light_curve = self.post_injection_transform(
125139
standard_light_curve
126140
)
@@ -129,10 +143,12 @@ def __iter__(self):
129143
LightCurveCollectionType.INJECTEE,
130144
LightCurveCollectionType.STANDARD_AND_INJECTEE,
131145
]:
132-
for (injectable_light_curve_collection_iter) in injectable_light_curve_collection_iters:
133-
injectable_light_curve = next(
146+
for (injectable_light_curve_collection_iter,
147+
injectable_observation_from_path_function) in injectable_light_curve_collection_iters:
148+
injectable_light_path = next(
134149
injectable_light_curve_collection_iter
135150
)
151+
injectable_light_curve = injectable_observation_from_path_function(injectable_light_path)
136152
injectee_light_curve = next(base_collection_iter)
137153
injected_light_curve = inject_light_curve(
138154
injectee_light_curve, injectable_light_curve
@@ -188,6 +204,12 @@ def new(
188204
)
189205
return instance
190206

207+
def seed_random(self, seed: int):
    """
    Replaces the random number generator of every collection's path getter with a newly seeded one.

    The standard, injectee, and injectable collections are all reseeded. Each path getter receives
    its own `Random` instance, and all of them are created from the same seed.

    :param seed: The seed to create each random number generator with.
    """
    collection_groups = (
        self.standard_light_curve_collections,
        self.injectee_light_curve_collections,
        self.injectable_light_curve_collections,
    )
    for light_curve_collections in collection_groups:
        for light_curve_collection in light_curve_collections:
            path_getter = light_curve_collection.path_getter
            path_getter.random_number_generator = Random(seed)
212+
191213

192214
def inject_light_curve(
193215
injectee_observation: LightCurveObservation,

tests/integration_tests/__init__.py

Whitespace-only changes.
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from pathlib import Path
2+
3+
import numpy as np
4+
import numpy.typing as npt
5+
import torch
6+
from torch.utils.data import DataLoader
7+
8+
from qusi.internal.light_curve_collection import LightCurveObservationCollection
9+
from qusi.internal.light_curve_dataset import LightCurveDataset
10+
from qusi.internal.light_curve_transforms import from_light_curve_observation_to_fluxes_array_and_label_array, \
11+
pair_array_to_tensor
12+
13+
14+
def get_paths() -> list[Path]:
    """Returns the eight stand-in light curve paths, named '1' through '8' in ascending order."""
    return [Path(str(path_number)) for path_number in range(1, 9)]
16+
17+
def load_times_and_fluxes_from_path(path: Path) -> tuple[npt.NDArray, npt.NDArray]:
    """
    Stand-in loader that derives the times and fluxes from the path name rather than reading a file.

    :param path: A path whose string form is parseable as a number (e.g., `Path('3')`).
    :return: The times array and the fluxes array, each a separate single-element array holding
             that number as a float.
    """
    # The return annotation is a tuple type; a list literal is not a valid type annotation.
    value = float(str(path))
    return np.array([value]), np.array([value])
20+
21+
def load_label_from_path_function(path: Path) -> int:
    """
    Stand-in label loader that derives the label from the path name.

    :param path: A path whose string form is parseable as an integer (e.g., `Path('4')`).
    :return: Ten times the integer encoded in the path name.
    """
    return 10 * int(str(path))
24+
25+
def post_injection_transform(x):
    """
    Converts a light curve observation into the form the data loader batches.

    The observation is first turned into a fluxes array and label array pair, and that pair is then
    passed through `pair_array_to_tensor`.

    :param x: The light curve observation to transform.
    :return: The result of `pair_array_to_tensor` applied to the fluxes and label arrays.
    """
    fluxes_and_label_arrays = from_light_curve_observation_to_fluxes_array_and_label_array(x)
    return pair_array_to_tensor(fluxes_and_label_arrays)
29+
30+
31+
def test_light_curve_dataset_with_and_without_multiple_workers_gives_same_batch_order():
    """
    Checks that the first two batches produced by a two-worker data loader are not identical.

    NOTE(review): The test name mentions comparing with and without multiple workers, but only a
    multi-worker data loader is built here, and the assertion is that its first two batches differ.
    """
    collection = LightCurveObservationCollection.new(
        get_paths_function=get_paths,
        load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
        load_label_from_path_function=load_label_from_path_function,
    )
    dataset = LightCurveDataset.new(
        standard_light_curve_collections=[collection],
        post_injection_transform=post_injection_transform,
    )
    dataloader = DataLoader(dataset, batch_size=4, num_workers=2, prefetch_factor=1)
    batch_iter = iter(dataloader)
    # Keep only the first column of the first element of each batch for the comparison.
    first_batch = next(batch_iter)[0].numpy()[:, 0]
    second_batch = next(batch_iter)[0].numpy()[:, 0]
    assert not np.array_equal(first_batch, second_batch)

0 commit comments

Comments
 (0)