diff --git a/src/gflownet/models/seq_transformer.py b/src/gflownet/models/seq_transformer.py index e0673430..da9e52d9 100644 --- a/src/gflownet/models/seq_transformer.py +++ b/src/gflownet/models/seq_transformer.py @@ -60,6 +60,8 @@ def __init__( else: self.output = MLPWithDropout(num_hid, num_outs, [2 * num_hid, 2 * num_hid], mc.dropout) self.num_hid = num_hid + # TODO: Merge non-autoregressive implementations of sequence generation + assert not cfg.model.do_separate_pb, "Not implemented for SeqTransformerGFN (since P_B=1 when autoregressive)." def forward(self, xs: SeqBatch, cond, batched=False): """Returns a GraphActionCategorical and a tensor of state predictions.