From 6dcce365d04951e75f4b44da408fbfed39048e37 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 9 Feb 2024 23:54:12 -0500 Subject: [PATCH] add a flag for predicting n --- src/gflownet/algo/config.py | 3 +++ src/gflownet/online_trainer.py | 1 + 2 files changed, 4 insertions(+) diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index 26839ae3..6184bdfc 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -29,6 +29,8 @@ class TBConfig: Whether to correct for idempotent actions do_parameterize_p_b : bool Whether to parameterize the P_B distribution (otherwise it is uniform) + do_predict_n : bool + Whether to predict the number of paths in the graph do_length_normalize : bool Whether to normalize the loss by the length of the trajectory subtb_max_len : int @@ -45,6 +47,7 @@ class TBConfig: variant: TBVariant = TBVariant.TB do_correct_idempotent: bool = False do_parameterize_p_b: bool = False + do_predict_n: bool = False do_sample_p_b: bool = False do_length_normalize: bool = False subtb_max_len: int = 128 diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index f36b8211..81fd7ae7 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -31,6 +31,7 @@ def setup_model(self): self.ctx, self.cfg, do_bck=self.cfg.algo.tb.do_parameterize_p_b, + num_graph_out=self.cfg.algo.tb.do_predict_n + 1, ) def setup_algo(self):