1
1
from __future__ import annotations
2
2
3
3
import copy
4
+ import itertools
4
5
import math
5
6
import re
6
7
import shutil
7
8
import socket
8
9
from enum import Enum
9
10
from functools import partial
10
11
from pathlib import Path
12
+ from random import Random
11
13
from typing import TYPE_CHECKING , Any , Callable , TypeVar
12
14
13
15
import numpy as np
@@ -75,52 +77,64 @@ def __init__(
75
77
)
76
78
raise ValueError (error_message )
77
79
self .post_injection_transform : Callable [[Any ], Any ] = post_injection_transform
80
+ self .worker_randomizing_set : bool = False
78
81
79
82
def __iter__ (self ):
83
+ if not self .worker_randomizing_set :
84
+ worker_info = torch .utils .data .get_worker_info ()
85
+ if worker_info is not None :
86
+ self .seed_random (worker_info .id )
87
+ self .worker_randomizing_set = True
80
88
base_light_curve_collection_iter_and_type_pairs : list [
81
- tuple [Iterator [LightCurveObservation ], LightCurveCollectionType ]
89
+ tuple [Iterator [Path ], Callable [[ Path ], LightCurveObservation ], LightCurveCollectionType ]
82
90
] = []
83
91
injectee_collections = copy .copy (self .injectee_light_curve_collections )
84
92
for standard_collection in self .standard_light_curve_collections :
85
93
if standard_collection in injectee_collections :
86
94
base_light_curve_collection_iter_and_type_pairs .append (
87
95
(
88
- loop_iter_function (standard_collection .observation_iter ),
96
+ loop_iter_function (standard_collection .path_iter ),
97
+ standard_collection .observation_from_path ,
89
98
LightCurveCollectionType .STANDARD_AND_INJECTEE ,
90
99
)
91
100
)
92
101
injectee_collections .remove (standard_collection )
93
102
else :
94
103
base_light_curve_collection_iter_and_type_pairs .append (
95
104
(
96
- loop_iter_function (standard_collection .observation_iter ),
105
+ loop_iter_function (standard_collection .path_iter ),
106
+ standard_collection .observation_from_path ,
97
107
LightCurveCollectionType .STANDARD ,
98
108
)
99
109
)
100
110
for injectee_collection in injectee_collections :
101
111
base_light_curve_collection_iter_and_type_pair = (
102
- loop_iter_function (injectee_collection .observation_iter ),
112
+ loop_iter_function (injectee_collection .path_iter ),
113
+ injectee_collection .observation_from_path ,
103
114
LightCurveCollectionType .INJECTEE ,
104
115
)
105
116
base_light_curve_collection_iter_and_type_pairs .append (base_light_curve_collection_iter_and_type_pair )
106
117
injectable_light_curve_collection_iters : list [
107
- Iterator [LightCurveObservation ]
118
+ tuple [ Iterator [Path ], Callable [[ Path ], LightCurveObservation ] ]
108
119
] = []
109
120
for injectable_collection in self .injectable_light_curve_collections :
110
- injectable_light_curve_collection_iter = loop_iter_function (injectable_collection .observation_iter )
111
- injectable_light_curve_collection_iters .append (injectable_light_curve_collection_iter )
121
+ injectable_light_curve_collection_iter = loop_iter_function (injectable_collection .path_iter )
122
+ injectable_light_curve_collection_iters .append (
123
+ (injectable_light_curve_collection_iter , injectable_collection .observation_from_path ))
112
124
while True :
113
125
for (
114
126
base_light_curve_collection_iter_and_type_pair
115
127
) in base_light_curve_collection_iter_and_type_pairs :
116
- (base_collection_iter , collection_type ) = base_light_curve_collection_iter_and_type_pair
128
+ (base_collection_iter , observation_from_path_function ,
129
+ collection_type ) = base_light_curve_collection_iter_and_type_pair
117
130
if collection_type in [
118
131
LightCurveCollectionType .STANDARD ,
119
132
LightCurveCollectionType .STANDARD_AND_INJECTEE ,
120
133
]:
121
134
# TODO: Preprocessing step should be here. Or maybe that should all be on the light curve collection
122
135
# as well? Or passed in somewhere else?
123
- standard_light_curve = next (base_collection_iter )
136
+ standard_path = next (base_collection_iter )
137
+ standard_light_curve = observation_from_path_function (standard_path )
124
138
transformed_standard_light_curve = self .post_injection_transform (
125
139
standard_light_curve
126
140
)
@@ -129,10 +143,12 @@ def __iter__(self):
129
143
LightCurveCollectionType .INJECTEE ,
130
144
LightCurveCollectionType .STANDARD_AND_INJECTEE ,
131
145
]:
132
- for (injectable_light_curve_collection_iter ) in injectable_light_curve_collection_iters :
133
- injectable_light_curve = next (
146
+ for (injectable_light_curve_collection_iter ,
147
+ injectable_observation_from_path_function ) in injectable_light_curve_collection_iters :
148
+ injectable_light_path = next (
134
149
injectable_light_curve_collection_iter
135
150
)
151
+ injectable_light_curve = injectable_observation_from_path_function (injectable_light_path )
136
152
injectee_light_curve = next (base_collection_iter )
137
153
injected_light_curve = inject_light_curve (
138
154
injectee_light_curve , injectable_light_curve
@@ -188,6 +204,12 @@ def new(
188
204
)
189
205
return instance
190
206
207
+ def seed_random (self , seed : int ):
208
+ for collection_group in [self .standard_light_curve_collections , self .injectee_light_curve_collections ,
209
+ self .injectable_light_curve_collections ]:
210
+ for collection in collection_group :
211
+ collection .path_getter .random_number_generator = Random (seed )
212
+
191
213
192
214
def inject_light_curve (
193
215
injectee_observation : LightCurveObservation ,
0 commit comments