Skip to content

Commit

Permalink
final fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
bengioe committed May 8, 2024
1 parent aff465e commit 4c53478
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
9 changes: 8 additions & 1 deletion src/gflownet/models/seq_transformer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# This code is adapted from https://github.com/MJ10/mo_gfn
import math
from typing import Optional

import torch
import torch.nn as nn
Expand Down Expand Up @@ -51,14 +52,20 @@ def __init__(
self.embedding = nn.Embedding(env_ctx.num_tokens, num_hid)
encoder_layers = nn.TransformerEncoderLayer(num_hid, mc.seq_transformer.num_heads, num_hid, dropout=mc.dropout)
self.encoder = nn.TransformerEncoder(encoder_layers, mc.num_layers)
self.logZ = nn.Linear(env_ctx.num_cond_dim, 1)
self._logZ = nn.Linear(env_ctx.num_cond_dim, 1)
if self.use_cond:
self.output = MLPWithDropout(num_hid + num_hid, num_outs, [4 * num_hid, 4 * num_hid], mc.dropout)
self.cond_embed = nn.Linear(env_ctx.num_cond_dim, num_hid)
else:
self.output = MLPWithDropout(num_hid, num_outs, [2 * num_hid, 2 * num_hid], mc.dropout)
self.num_hid = num_hid

def logZ(self, cond_info: Optional[torch.Tensor]):
if cond_info is None:
return self._logZ(torch.ones((1, 1), device=self._logZ.weight.device))
return self._logZ(cond_info)


def forward(self, xs: SeqBatch, cond, batched=False):
"""Returns a GraphActionCategorical and a tensor of state predictions.
Expand Down
1 change: 0 additions & 1 deletion src/gflownet/tasks/seh_frag_moo.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,6 @@ def setup_algo(self):
def setup_task(self):
self.cfg.cond.moo.num_objectives = len(self.cfg.task.seh_moo.objectives)
self.task = SEHMOOTask(
dataset=self.training_data,
cfg=self.cfg,
wrap_model=self._wrap_for_mp,
)
Expand Down
1 change: 1 addition & 0 deletions src/gflownet/tasks/toy_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def set_default_hps(self, cfg: Config):
cfg.hostname = socket.gethostname()
cfg.pickle_mp_messages = False
cfg.num_workers = 8
cfg.num_validation_gen_steps = 1
cfg.opt.learning_rate = 1e-4
cfg.opt.weight_decay = 1e-8
cfg.opt.momentum = 0.9
Expand Down

0 comments on commit 4c53478

Please sign in to comment.