From 4c53478fe82c001a35646d1f0ec81f9f844c2d36 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Wed, 8 May 2024 15:23:44 -0600 Subject: [PATCH] final fixes --- src/gflownet/models/seq_transformer.py | 9 ++++++++- src/gflownet/tasks/seh_frag_moo.py | 1 - src/gflownet/tasks/toy_seq.py | 1 + 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/gflownet/models/seq_transformer.py b/src/gflownet/models/seq_transformer.py index b1a4173a..84e604d1 100644 --- a/src/gflownet/models/seq_transformer.py +++ b/src/gflownet/models/seq_transformer.py @@ -1,5 +1,6 @@ # This code is adapted from https://github.com/MJ10/mo_gfn import math +from typing import Optional import torch import torch.nn as nn @@ -51,7 +52,7 @@ def __init__( self.embedding = nn.Embedding(env_ctx.num_tokens, num_hid) encoder_layers = nn.TransformerEncoderLayer(num_hid, mc.seq_transformer.num_heads, num_hid, dropout=mc.dropout) self.encoder = nn.TransformerEncoder(encoder_layers, mc.num_layers) - self.logZ = nn.Linear(env_ctx.num_cond_dim, 1) + self._logZ = nn.Linear(env_ctx.num_cond_dim, 1) if self.use_cond: self.output = MLPWithDropout(num_hid + num_hid, num_outs, [4 * num_hid, 4 * num_hid], mc.dropout) self.cond_embed = nn.Linear(env_ctx.num_cond_dim, num_hid) @@ -59,6 +60,12 @@ def __init__( self.output = MLPWithDropout(num_hid, num_outs, [2 * num_hid, 2 * num_hid], mc.dropout) self.num_hid = num_hid + def logZ(self, cond_info: Optional[torch.Tensor]): + if cond_info is None: + return self._logZ(torch.ones((1, 1), device=self._logZ.weight.device)) + return self._logZ(cond_info) + + def forward(self, xs: SeqBatch, cond, batched=False): """Returns a GraphActionCategorical and a tensor of state predictions. diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 79c599e7..c4f0b67f 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -232,7 +232,6 @@ def setup_algo(self): def setup_task(self): self.cfg.cond.moo.num_objectives = len(self.cfg.task.seh_moo.objectives) self.task = SEHMOOTask( - dataset=self.training_data, cfg=self.cfg, wrap_model=self._wrap_for_mp, ) diff --git a/src/gflownet/tasks/toy_seq.py b/src/gflownet/tasks/toy_seq.py index a9481ca5..a215948b 100644 --- a/src/gflownet/tasks/toy_seq.py +++ b/src/gflownet/tasks/toy_seq.py @@ -45,6 +45,7 @@ def set_default_hps(self, cfg: Config): cfg.hostname = socket.gethostname() cfg.pickle_mp_messages = False cfg.num_workers = 8 + cfg.num_validation_gen_steps = 1 cfg.opt.learning_rate = 1e-4 cfg.opt.weight_decay = 1e-8 cfg.opt.momentum = 0.9