@@ -19,7 +19,7 @@
 from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext
 from gflownet.envs.seq_building_env import SeqBatch
 from gflownet.utils.misc import create_logger
-from gflownet.utils.multiprocessing_proxy import mp_object_wrapper
+from gflownet.utils.multiprocessing_proxy import mp_object_wrapper, BufferUnpickler
 from gflownet.utils.misc import prepend_keys, average_values_across_dicts
 from gflownet.utils.metrics_final_eval import compute_metrics
 import wandb
@@ -219,7 +219,8 @@ def _wrap_for_mp(self, obj, send_to_device=False):
                 self.cfg.num_workers,
                 cast_types=(gd.Batch, GraphActionCategorical, SeqBatch),
                 pickle_messages=self.cfg.pickle_mp_messages,
-            )
+                sb_size=self.cfg.mp_buffer_size,
+            ).placeholder
             return placeholder, torch.device("cpu")
         else:
             return obj, self.device
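
For context on the change above: mp_object_wrapper proxies an object (typically the model) to the data-loader workers; it now also takes a shared-buffer size (sb_size = cfg.mp_buffer_size), and the workers receive its .placeholder rather than the wrapper itself. A minimal sketch of the placeholder idea, assuming calls are forwarded over multiprocessing queues (all names and signatures below are illustrative, not the gflownet API):

import torch.multiprocessing as mp

class _Placeholder:
    """Held by a worker: forwards method calls to the owning process."""

    def __init__(self, request_q, response_q):
        self.request_q, self.response_q = request_q, response_q

    def __getattr__(self, name):
        def remote_call(*args, **kwargs):
            self.request_q.put((name, args, kwargs))
            return self.response_q.get()
        return remote_call

class ObjectProxy:
    """Held by the main process: the real object plus a serving loop."""

    def __init__(self, obj):
        self.obj = obj
        self.request_q, self.response_q = mp.Queue(), mp.Queue()
        self.placeholder = _Placeholder(self.request_q, self.response_q)

    def serve_once(self):
        # Execute one forwarded call against the real object.
        name, args, kwargs = self.request_q.get()
        self.response_q.put(getattr(self.obj, name)(*args, **kwargs))

The sb_size argument presumably sizes a shared-memory region for returning large tensor results without pushing them through a pipe; the sketch after the _maybe_resolve_shared_buffer hunk below illustrates that half of the mechanism.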
@@ -248,6 +249,7 @@ def build_training_data_loader(self) -> DataLoader:
             random_action_prob=self.cfg.algo.train_random_action_prob,
             random_traj_prob=self.cfg.algo.train_random_traj_prob,
             hindsight_ratio=self.cfg.replay.hindsight_ratio,
+            mp_cfg=(self.cfg.num_workers, self.cfg.pickle_mp_messages, self.cfg.mp_buffer_size),
         )
         for hook in self.sampling_hooks:
             iterator.add_log_hook(hook)
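
Downstream, the iterator presumably unpacks mp_cfg and allocates one shared result buffer per worker, since the _maybe_resolve_shared_buffer helper added further below reads dl.dataset.mp_buffer_size and dl.dataset.result_buffer[wid]. A hypothetical sketch of that setup (function name and layout are ours, not the repo's):

import torch

def setup_result_buffers(mp_cfg):
    # mp_cfg = (num_workers, pickle_messages, mp_buffer_size), as passed above.
    num_workers, _pickle_messages, mp_buffer_size = mp_cfg
    if not mp_buffer_size:
        return None
    # One shared uint8 buffer per worker process.
    return [
        torch.empty(mp_buffer_size, dtype=torch.uint8).share_memory_()
        for _ in range(max(num_workers, 1))
    ]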
@@ -258,7 +260,7 @@ def build_training_data_loader(self) -> DataLoader:
             persistent_workers=self.cfg.num_workers > 0,
             # The 2 here is an odd quirk of torch 1.10, it is fixed and
             # replaced by None in torch 2.
-            prefetch_factor=1 if self.cfg.num_workers else 2,
+            prefetch_factor=1 if self.cfg.num_workers else (None if torch.__version__.startswith('2') else 2),
             generator=g,
             worker_init_fn=seed_worker
         )
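
The same prefetch_factor expression recurs in the validation and final loaders below; a follow-up wanting to avoid the repetition could factor it into a helper along these lines (helper name is ours):

import torch

def compat_prefetch_factor(num_workers: int):
    # With worker processes, prefetch one batch at a time. Without workers,
    # torch 2.x requires None, whereas torch 1.x used a default of 2.
    if num_workers > 0:
        return 1
    return None if torch.__version__.startswith('2') else 2

startswith('2') is a coarse version check; parsing torch.__version__ would be more robust, but it matches the spirit of the inline comment about the torch 1.10 quirk.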
@@ -284,6 +286,7 @@ def build_validation_data_loader(self) -> DataLoader:
             stream=False,
             random_action_prob=self.cfg.algo.valid_random_action_prob,
             is_validation=True,
+            mp_cfg=(self.cfg.num_workers, self.cfg.pickle_mp_messages, self.cfg.mp_buffer_size),
         )
         for hook in self.valid_sampling_hooks:
             iterator.add_log_hook(hook)
@@ -292,7 +295,7 @@ def build_validation_data_loader(self) -> DataLoader:
             batch_size=None,
             num_workers=self.cfg.num_workers,
             persistent_workers=self.cfg.num_workers > 0,
-            prefetch_factor=1 if self.cfg.num_workers else 2,
+            prefetch_factor=1 if self.cfg.num_workers else (None if torch.__version__.startswith('2') else 2),
             generator=g,
             worker_init_fn=seed_worker
         )
@@ -322,6 +325,7 @@ def build_final_data_loader(self) -> DataLoader:
             hindsight_ratio=0.0,
             is_validation=True,
             # init_train_iter=self.cfg.num_training_steps,
+            mp_cfg=(self.cfg.num_workers, self.cfg.pickle_mp_messages, self.cfg.mp_buffer_size),
         )
         for hook in self.sampling_hooks:
             iterator.add_log_hook(hook)
@@ -330,11 +334,19 @@ def build_final_data_loader(self) -> DataLoader:
             batch_size=None,
             num_workers=self.cfg.num_workers,
             persistent_workers=self.cfg.num_workers > 0,
-            prefetch_factor=1 if self.cfg.num_workers else 2,
+            prefetch_factor=1 if self.cfg.num_workers else (None if torch.__version__.startswith('2') else 2),
             generator=g,
             worker_init_fn=seed_worker
         )
 
+    def _maybe_resolve_shared_buffer(self, batch, dl: DataLoader):
+        if dl.dataset.mp_buffer_size and isinstance(batch, (tuple, list)):
+            batch, wid = batch
+            batch = BufferUnpickler(dl.dataset.result_buffer[wid], batch, self.device).load()
+        elif isinstance(batch, (gd.Batch, SeqBatch)):
+            batch = batch.to(self.device)
+        return batch
+
     def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]:
         try:
             loss, info = self.algo.compute_batch_losses(self.model, batch)
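
Our reading of the new helper: when mp_buffer_size is set, a worker serializes its result into its own shared buffer (dl.dataset.result_buffer[wid]) and yields only a small message plus its worker id, and the main process rebuilds the full object via BufferUnpickler; otherwise batches are simply moved to the device as before. A self-contained sketch of that round trip over a shared uint8 tensor (the worker-side counterpart of BufferUnpickler is assumed here, not taken from the repo):

import pickle
import torch

def worker_send(buf: torch.Tensor, obj) -> int:
    # Worker side: serialize into the shared buffer; ship only the size.
    payload = pickle.dumps(obj)
    assert len(payload) <= buf.numel(), "mp_buffer_size too small for this batch"
    buf[: len(payload)] = torch.frombuffer(bytearray(payload), dtype=torch.uint8)
    return len(payload)

def main_receive(buf: torch.Tensor, size: int):
    # Main-process side: rebuild the object from the worker's buffer.
    return pickle.loads(bytes(buf[:size].numpy()))

Note that the resolved object is itself a (batch, info) tuple, which is why the call sites in run() below unpack the helper's return value.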
@@ -383,7 +395,8 @@ def run(self, logger=None):
         start = self.cfg.start_at_step + 1
         num_training_steps = self.cfg.num_training_steps
         logger.info("Starting training")
-        for it, (batch, _) in zip(range(start, 1 + num_training_steps), cycle(train_dl)):
+        for it, batch in zip(range(start, 1 + num_training_steps), cycle(train_dl)):
+            batch, _ = self._maybe_resolve_shared_buffer(batch, train_dl)
             epoch_idx = it // epoch_length
             batch_idx = it % epoch_length
             if self.replay_buffer is not None and len(self.replay_buffer) < self.replay_buffer.warmup:
@@ -404,7 +417,8 @@ def run(self, logger=None):
                 candidates_eval_infos = []
                 # for batch in valid_dl:
                 # validate on at least 10 batches
-                for valid_it, (batch, candidates_eval_info) in zip(range(10), cycle(valid_dl)):
+                for valid_it, batch in zip(range(10), cycle(valid_dl)):
+                    batch, candidates_eval_info = self._maybe_resolve_shared_buffer(batch, valid_dl)
                     # print("valid_it", valid_it)
                     candidates_eval_infos.append(candidates_eval_info)
                     metrics = self.evaluate_batch(batch.to(self.device), epoch_idx, batch_idx)
@@ -440,10 +454,11 @@ def run(self, logger=None):
         if num_final_gen_steps:
             gen_candidates_list = []
             logger.info(f"Generating final {num_final_gen_steps} batches ...")
-            for it, (_, gen_candidates_eval_info) in zip(
+            for it, batch in zip(
                 range(num_training_steps, num_training_steps + num_final_gen_steps + 1),
                 cycle(final_dl),
             ):
+                _, gen_candidates_eval_info = self._maybe_resolve_shared_buffer(batch, final_dl)
                 gen_candidates_list.append(gen_candidates_eval_info)
 
             info_final_gen = compute_metrics(gen_candidates_list, cand_type=self.task.cand_type, k=self.cfg.evaluation.k, reward_thresh=self.cfg.evaluation.reward_thresh, distance_thresh=self.cfg.evaluation.distance_thresh)