Skip to content

Commit

Permalink
add a flag for predicting n
Browse files Browse the repository at this point in the history
  • Loading branch information
SobhanMP committed Feb 10, 2024
1 parent 0760b87 commit 6dcce36
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/gflownet/algo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class TBConfig:
Whether to correct for idempotent actions
do_parameterize_p_b : bool
Whether to parameterize the P_B distribution (otherwise it is uniform)
do_predict_n : bool
Whether to predict the number of paths in the graph
do_length_normalize : bool
Whether to normalize the loss by the length of the trajectory
subtb_max_len : int
Expand All @@ -45,6 +47,7 @@ class TBConfig:
variant: TBVariant = TBVariant.TB
do_correct_idempotent: bool = False
do_parameterize_p_b: bool = False
do_predict_n: bool = False
do_sample_p_b: bool = False
do_length_normalize: bool = False
subtb_max_len: int = 128
Expand Down
1 change: 1 addition & 0 deletions src/gflownet/online_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def setup_model(self):
self.ctx,
self.cfg,
do_bck=self.cfg.algo.tb.do_parameterize_p_b,
num_graph_out=self.cfg.algo.tb.do_predict_n + 1,
)

def setup_algo(self):
Expand Down

0 comments on commit 6dcce36

Please sign in to comment.