Skip to content

Commit ab6c831

Browse files
committed
partial merge
1 parent 0350235 commit ab6c831

File tree

5 files changed

+217
-40
lines changed

5 files changed

+217
-40
lines changed

src/gflownet/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ class Config:
112112
num_workers: int = 0
113113
hostname: Optional[str] = None
114114
pickle_mp_messages: bool = False
115+
mp_buffer_size: Optional[int] = 32 * 1024 ** 2 # 32Mb
115116
git_hash: Optional[str] = None
116117
overwrite_existing_exp: bool = True
117118
algo: AlgoConfig = AlgoConfig()

src/gflownet/data/sampling_iterator.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from gflownet.data.replay_buffer import ReplayBuffer
1414
from gflownet.envs.graph_building_env import GraphActionCategorical
15+
from gflownet.utils.multiprocessing_proxy import BufferPickler, SharedPinnedBuffer
1516

1617

1718
class SamplingIterator(IterableDataset):
@@ -44,7 +45,8 @@ def __init__(
4445
random_traj_prob: float = 0.0,
4546
hindsight_ratio: float = 0.0,
4647
init_train_iter: int = 0,
47-
is_validation: bool = False
48+
is_validation: bool = False,
49+
mp_cfg = None,
4850
):
4951
"""Parameters
5052
----------
@@ -110,6 +112,7 @@ def __init__(
110112
self.train_it = init_train_iter
111113
self.is_validation = is_validation
112114
self.do_validate_batch = False # Turn this on for debugging
115+
self.num_workers, _, self.mp_buffer_size = mp_cfg
113116

114117
# Slightly weird semantics, but if we're sampling x given some fixed cond info (data)
115118
# then "offline" now refers to cond info and online to x, so no duplication and we don't end
@@ -125,6 +128,8 @@ def __init__(
125128
self.log = SQLiteLog()
126129
self.log_hooks: List[Callable] = []
127130

131+
self.setup_mp_buffers()
132+
128133
def add_log_hook(self, hook: Callable):
129134
self.log_hooks.append(hook)
130135

@@ -282,17 +287,21 @@ def __iter__(self):
282287
# and sample replay_batch_size of them to add to the batch
283288

284289
# cond_info is a dict, so we need to convert it to a list of dicts
285-
cond_info = [{k: v[i] for k, v in cond_info.items()} for i in range(num_offline + num_online)]
290+
cond_info_ = [{k: v[i] for k, v in cond_info.items()} for i in range(num_offline + num_online)]
286291

287292
# push the online trajectories in the replay buffer and sample a new 'online' batch
288293
for i in range(num_offline, len(trajs)):
294+
if not is_valid[i].item():
295+
continue
289296
self.replay_buffer.push(
290297
deepcopy(trajs[i]),
291298
deepcopy(log_rewards[i]),
292299
deepcopy(flat_rewards[i]),
293-
deepcopy(cond_info[i]),
300+
deepcopy(cond_info_[i]),
294301
deepcopy(is_valid[i]),
295302
)
303+
if self.replay_buffer is not None and len(self.replay_buffer) > self.replay_buffer.warmup:
304+
cond_info = [{k: v[i] for k, v in cond_info.items()} for i in range(num_offline + num_online)]
296305
replay_trajs, replay_logr, replay_fr, replay_condinfo, replay_valid = self.replay_buffer.sample(
297306
self.replay_batch_size
298307
)
@@ -340,11 +349,11 @@ def __iter__(self):
340349

341350
# TODO: need to change this for non-molecule environments
342351
try:
343-
smiles = [Chem.MolToSmiles(self.ctx.graph_to_mol(traj["result"])) for traj in trajs]
352+
smiles = [self.ctx.object_to_log_repr(traj["result"]) for traj in trajs]
344353
except:
345354
smiles = [traj["result"].__repr__() for traj in trajs]
346355
# alternative: [traj["smi"] for traj in trajs]
347-
yield batch, (smiles, flat_rewards)
356+
yield self._maybe_put_in_mp_buffer((batch, (smiles, flat_rewards)))
348357

349358
def validate_batch(self, batch, trajs):
350359
for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + (
@@ -400,6 +409,18 @@ def log_generated(self, trajs, rewards, flat_rewards, cond_info):
400409

401410
self.log.insert_many(data, data_labels)
402411

412+
def setup_mp_buffers(self):
    """Create the per-worker result buffers, or switch buffer transport off.

    When ``self.num_workers`` is positive and ``self.mp_buffer_size`` is
    truthy, ``self.result_buffer`` becomes a list holding one
    ``SharedPinnedBuffer(self.mp_buffer_size)`` per worker. In every other
    case ``self.mp_buffer_size`` is reset to ``None`` and ``result_buffer``
    is left untouched; ``_maybe_put_in_mp_buffer`` tests that attribute to
    decide whether to serialize into a buffer.

    NOTE(review): the size is presumably in bytes (the config default is
    ``32 * 1024 ** 2``) -- confirm against ``SharedPinnedBuffer``.
    """
    buffers_wanted = self.num_workers > 0 and self.mp_buffer_size
    if not buffers_wanted:
        # No workers or no size configured: mark the feature as disabled.
        self.mp_buffer_size = None
        return
    self.result_buffer = [
        SharedPinnedBuffer(self.mp_buffer_size)
        for _ in range(self.num_workers)
    ]
417+
418+
def _maybe_put_in_mp_buffer(self, batch):
    """Return ``batch`` as-is, or serialized into this worker's buffer.

    If ``self.mp_buffer_size`` is falsy the input object is returned
    unchanged. Otherwise ``batch`` is dumped with a ``BufferPickler`` built
    on ``self.result_buffer[self._wid]`` and the method returns the pair
    ``(dumps_result, self._wid)``.

    NOTE(review): ``self._wid`` is assigned outside this method; presumably
    it is the DataLoader worker index -- confirm where it is set.
    """
    if not self.mp_buffer_size:
        return batch
    wid = self._wid
    pickler = BufferPickler(self.result_buffer[wid])
    return (pickler.dumps(batch), wid)
423+
403424

404425
class SQLiteLog:
405426
def __init__(self, timeout=300):

src/gflownet/tasks/seh_frag.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ class SEHFragTrainer(StandardOnlineTrainer):
8080
def set_default_hps(self, cfg: Config):
8181
cfg.hostname = socket.gethostname()
8282
cfg.pickle_mp_messages = False
83+
cfg.mp_buffer_size = 32 * 1024 ** 2 # 32Mb should be enough for this setup
8384
cfg.num_workers = 5
8485

8586
cfg.opt.learning_rate = 1e-4

src/gflownet/trainer.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext
2020
from gflownet.envs.seq_building_env import SeqBatch
2121
from gflownet.utils.misc import create_logger
22-
from gflownet.utils.multiprocessing_proxy import mp_object_wrapper
22+
from gflownet.utils.multiprocessing_proxy import mp_object_wrapper, BufferUnpickler
2323
from gflownet.utils.misc import prepend_keys, average_values_across_dicts
2424
from gflownet.utils.metrics_final_eval import compute_metrics
2525
import wandb
@@ -219,7 +219,8 @@ def _wrap_for_mp(self, obj, send_to_device=False):
219219
self.cfg.num_workers,
220220
cast_types=(gd.Batch, GraphActionCategorical, SeqBatch),
221221
pickle_messages=self.cfg.pickle_mp_messages,
222-
)
222+
sb_size=self.cfg.mp_buffer_size,
223+
).placeholder
223224
return placeholder, torch.device("cpu")
224225
else:
225226
return obj, self.device
@@ -248,6 +249,7 @@ def build_training_data_loader(self) -> DataLoader:
248249
random_action_prob=self.cfg.algo.train_random_action_prob,
249250
random_traj_prob=self.cfg.algo.train_random_traj_prob,
250251
hindsight_ratio=self.cfg.replay.hindsight_ratio,
252+
mp_cfg=(self.cfg.num_workers, self.cfg.pickle_mp_messages, self.cfg.mp_buffer_size),
251253
)
252254
for hook in self.sampling_hooks:
253255
iterator.add_log_hook(hook)
@@ -258,7 +260,7 @@ def build_training_data_loader(self) -> DataLoader:
258260
persistent_workers=self.cfg.num_workers > 0,
259261
# The 2 here is an odd quirk of torch 1.10, it is fixed and
260262
# replaced by None in torch 2.
261-
prefetch_factor=1 if self.cfg.num_workers else 2,
263+
prefetch_factor=1 if self.cfg.num_workers else (None if torch.__version__.startswith('2') else 2),
262264
generator=g,
263265
worker_init_fn=seed_worker
264266
)
@@ -284,6 +286,7 @@ def build_validation_data_loader(self) -> DataLoader:
284286
stream=False,
285287
random_action_prob=self.cfg.algo.valid_random_action_prob,
286288
is_validation=True,
289+
mp_cfg=(self.cfg.num_workers, self.cfg.pickle_mp_messages, self.cfg.mp_buffer_size),
287290
)
288291
for hook in self.valid_sampling_hooks:
289292
iterator.add_log_hook(hook)
@@ -292,7 +295,7 @@ def build_validation_data_loader(self) -> DataLoader:
292295
batch_size=None,
293296
num_workers=self.cfg.num_workers,
294297
persistent_workers=self.cfg.num_workers > 0,
295-
prefetch_factor=1 if self.cfg.num_workers else 2,
298+
prefetch_factor=1 if self.cfg.num_workers else (None if torch.__version__.startswith('2') else 2),
296299
generator=g,
297300
worker_init_fn=seed_worker
298301
)
@@ -322,6 +325,7 @@ def build_final_data_loader(self) -> DataLoader:
322325
hindsight_ratio=0.0,
323326
is_validation=True,
324327
# init_train_iter=self.cfg.num_training_steps,
328+
mp_cfg=(self.cfg.num_workers, self.cfg.pickle_mp_messages, self.cfg.mp_buffer_size),
325329
)
326330
for hook in self.sampling_hooks:
327331
iterator.add_log_hook(hook)
@@ -330,11 +334,19 @@ def build_final_data_loader(self) -> DataLoader:
330334
batch_size=None,
331335
num_workers=self.cfg.num_workers,
332336
persistent_workers=self.cfg.num_workers > 0,
333-
prefetch_factor=1 if self.cfg.num_workers else 2,
337+
prefetch_factor=1 if self.cfg.num_workers else (None if torch.__version__.startswith('2') else 2),
334338
generator=g,
335339
worker_init_fn=seed_worker
336340
)
337341

342+
def _maybe_resolve_shared_buffer(self, batch, dl: DataLoader):
    """Turn an item yielded by ``dl`` into a usable batch object.

    Three cases, checked in order:

    * ``dl.dataset.mp_buffer_size`` is truthy and ``batch`` is a tuple or
      list: it is unpacked as ``(payload, wid)`` and the result of
      ``BufferUnpickler(dl.dataset.result_buffer[wid], payload,
      self.device).load()`` is returned.
    * ``batch`` is a ``gd.Batch`` or ``SeqBatch``: ``batch.to(self.device)``
      is returned.
    * anything else is returned unchanged.
    """
    dataset = dl.dataset
    if dataset.mp_buffer_size and isinstance(batch, (tuple, list)):
        # The worker sent a (serialized payload, worker id) pair.
        payload, wid = batch
        unpickler = BufferUnpickler(dataset.result_buffer[wid], payload, self.device)
        return unpickler.load()
    if isinstance(batch, (gd.Batch, SeqBatch)):
        return batch.to(self.device)
    return batch
349+
338350
def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]:
339351
try:
340352
loss, info = self.algo.compute_batch_losses(self.model, batch)
@@ -383,7 +395,8 @@ def run(self, logger=None):
383395
start = self.cfg.start_at_step + 1
384396
num_training_steps = self.cfg.num_training_steps
385397
logger.info("Starting training")
386-
for it, (batch, _) in zip(range(start, 1 + num_training_steps), cycle(train_dl)):
398+
for it, batch in zip(range(start, 1 + num_training_steps), cycle(train_dl)):
399+
batch, _ = self._maybe_resolve_shared_buffer(batch, train_dl)
387400
epoch_idx = it // epoch_length
388401
batch_idx = it % epoch_length
389402
if self.replay_buffer is not None and len(self.replay_buffer) < self.replay_buffer.warmup:
@@ -404,7 +417,8 @@ def run(self, logger=None):
404417
candidates_eval_infos = []
405418
# for batch in valid_dl:
406419
# validate on at least 10 batches
407-
for valid_it, (batch, candidates_eval_info) in zip(range(10), cycle(valid_dl)):
420+
for valid_it, batch in zip(range(10), cycle(valid_dl)):
421+
batch, candidates_eval_info = self._maybe_resolve_shared_buffer(batch, valid_dl)
408422
# print("valid_it", valid_it)
409423
candidates_eval_infos.append(candidates_eval_info)
410424
metrics = self.evaluate_batch(batch.to(self.device), epoch_idx, batch_idx)
@@ -440,10 +454,11 @@ def run(self, logger=None):
440454
if num_final_gen_steps:
441455
gen_candidates_list = []
442456
logger.info(f"Generating final {num_final_gen_steps} batches ...")
443-
for it, (_, gen_candidates_eval_info) in zip(
457+
for it, batch in zip(
444458
range(num_training_steps, num_training_steps + num_final_gen_steps + 1),
445459
cycle(final_dl),
446460
):
461+
_, gen_candidates_eval_info = self._maybe_resolve_shared_buffer(batch, final_dl)
447462
gen_candidates_list.append(gen_candidates_eval_info)
448463

449464
info_final_gen = compute_metrics(gen_candidates_list, cand_type=self.task.cand_type, k=self.cfg.evaluation.k, reward_thresh=self.cfg.evaluation.reward_thresh, distance_thresh=self.cfg.evaluation.distance_thresh)

0 commit comments

Comments
 (0)