Skip to content

Commit

Permalink
add assert for seq model
Browse files Browse the repository at this point in the history
  • Loading branch information
bengioe committed Feb 26, 2024
1 parent 6281907 commit e19c105
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/gflownet/models/seq_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def __init__(
else:
self.output = MLPWithDropout(num_hid, num_outs, [2 * num_hid, 2 * num_hid], mc.dropout)
self.num_hid = num_hid
# TODO: Merge non-autoregressive implementations of sequence generation
assert not cfg.model.do_separate_pb, "Not implemented for SeqTransformerGFN (since P_B=1 when autoregressive)."

def forward(self, xs: SeqBatch, cond, batched=False):
"""Returns a GraphActionCategorical and a tensor of state predictions.
Expand Down

0 comments on commit e19c105

Please sign in to comment.