
Commit be03b2f

lazaratan and bengioe authored
fixed sequence setup to be masked & p(x) computable (#111)
* fixed sequence setup to be masked & p(x) computable
* added distillation code for toy_seq

Co-authored-by: Emmanuel Bengio <emmanuel.bengio@recursionpharma.com>
1 parent d1982ec commit be03b2f

4 files changed, 327 insertions(+), 175 deletions(-)


src/gflownet/envs/seq_building_env.py

Lines changed: 18 additions & 5 deletions
@@ -16,8 +16,8 @@

 # For typing's sake, we'll pretend that a sequence is a graph.
 class Seq(Graph):
-    def __init__(self):
-        self.seq: list[Any] = []
+    def __init__(self, seq=None):
+        self.seq: list[Any] = [] if seq is None else seq

     def __repr__(self):
         return "".join(map(str, self.seq))
@@ -58,7 +58,8 @@ def reverse(self, g: Graph, ga: GraphAction):


 class SeqBatch:
-    def __init__(self, seqs: List[torch.Tensor], pad: int):
+    def __init__(self, seqs: List[torch.Tensor], pad: int, max_len: int = 2048):
+        self.max_len = max_len + 1  # +1 for BOS
         self.seqs = seqs
         self.x = pad_sequence(seqs, batch_first=False, padding_value=pad)
         self.mask = self.x.eq(pad).T
@@ -69,6 +70,14 @@ def __init__(self, seqs: List[torch.Tensor], pad: int):
         # Since we're feeding this batch object to graph-based algorithms, we have to use this naming, but this
         # is the total number of timesteps.
         self.num_graphs = self.lens.sum().item()
+        self.batch_stop_mask = torch.ones_like(self.logit_idx)[:, None]
+        self.batch_append_mask = (
+            torch.ones_like(self.logit_idx)
+            if self.lens.max() < self.max_len
+            else (self.logit_idx % self.max_len).lt(self.max_len - 1)
+        )[:, None].float()
+        self.tail_stop_mask = torch.ones((len(seqs), 1))
+        self.tail_append_mask = (self.lens[:, None] < self.max_len).float()

     def to(self, device):
         for name in dir(self):
@@ -84,7 +93,7 @@ class AutoregressiveSeqBuildingContext(GraphBuildingEnvContext):
     This context gets an agent to generate sequences of tokens from left to right, i.e. in an autoregressive fashion.
     """

-    def __init__(self, alphabet: Sequence[str], num_cond_dim=0):
+    def __init__(self, alphabet: Sequence[str], num_cond_dim=0, max_len=None):
         self.alphabet = alphabet
         self.action_type_order = [GraphActionType.Stop, GraphActionType.AddNode]

@@ -93,6 +102,7 @@ def __init__(self, alphabet: Sequence[str], num_cond_dim=0):
         self.pad_token = len(alphabet) + 1
         self.num_actions = len(alphabet) + 1  # Alphabet + Stop
         self.num_cond_dim = num_cond_dim
+        self.max_len = max_len

     def aidx_to_GraphAction(self, g: Data, action_idx: Tuple[int, int, int], fwd: bool = True) -> GraphAction:
         # Since there's only one "object" per timestep to act upon (in graph parlance), the row is always == 0
@@ -120,7 +130,7 @@ def graph_to_Data(self, g: Graph):
         return torch.tensor([self.bos_token] + s.seq, dtype=torch.long)

     def collate(self, graphs: List[Data]):
-        return SeqBatch(graphs, pad=self.pad_token)
+        return SeqBatch(graphs, pad=self.pad_token, max_len=self.max_len)

     def is_sane(self, g: Graph) -> bool:
         return True
@@ -131,3 +141,6 @@ def graph_to_mol(self, g: Graph):

     def object_to_log_repr(self, g: Graph):
         return self.graph_to_mol(g)
+
+    def mol_to_graph(self, mol) -> Graph:
+        return mol
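
As the commit title suggests, these masks are what make p(x) computable under a length cap: Stop stays valid in every state, while AddNode is disabled once a sequence reaches max_len, so the softmax only ever spreads probability over actions that can actually be taken. A minimal sketch of the tail-mask arithmetic in plain PyTorch (the lengths and max_len below are made-up example values, not from the commit):

    import torch

    max_len = 4 + 1                      # mirrors SeqBatch.__init__: +1 for the BOS token
    lens = torch.tensor([3, 5])          # per-sequence lengths, BOS included

    # Stop is always allowed at the tail of a sequence...
    tail_stop_mask = torch.ones((len(lens), 1))
    # ...but AddNode is only allowed while the sequence is below max_len.
    tail_append_mask = (lens[:, None] < max_len).float()

    print(tail_append_mask)  # tensor([[1.], [0.]]) -- the full-length sequence may only Stop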

src/gflownet/models/seq_transformer.py

Lines changed: 12 additions & 1 deletion
@@ -8,8 +8,10 @@
 from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnvContext
 from gflownet.envs.seq_building_env import SeqBatch
 from gflownet.models.config import SeqPosEnc
+from gflownet.models.graph_transformer import mlp
 from gflownet.envs.graph_building_env import GraphActionCategorical, GraphActionType

+
 class MLPWithDropout(nn.Module):
     def __init__(self, in_dim, out_dim, hidden_layers, dropout_prob, init_drop=False):
         super(MLPWithDropout, self).__init__()
@@ -62,7 +64,7 @@ def __init__(
         self.embedding = nn.Embedding(env_ctx.num_tokens, num_hid)
         encoder_layers = nn.TransformerEncoderLayer(num_hid, mc.seq_transformer.num_heads, num_hid, dropout=mc.dropout)
         self.encoder = nn.TransformerEncoder(encoder_layers, mc.num_layers)
-        self.logZ = nn.Linear(env_ctx.num_cond_dim, 1)
+        self.logZ = mlp(env_ctx.num_cond_dim, num_hid, 1, 2)  # nn.Linear(env_ctx.num_cond_dim, 1)
         if self.use_cond:
             self.output = MLPWithDropout(num_hid + num_hid, num_outs, [4 * num_hid, 4 * num_hid], mc.dropout)
             self.cond_embed = nn.Linear(env_ctx.num_cond_dim, num_hid)
@@ -109,6 +111,7 @@ def forward(self, xs: SeqBatch, cond, batched=False):
             state_preds = out[xs.logit_idx, 0:ns]  # (proper_time, num_state_out)
             stop_logits = out[xs.logit_idx, ns : ns + 1]  # (proper_time, 1)
             add_node_logits = out[xs.logit_idx, ns + 1 :]  # (proper_time, nout - 1)
+            masks = [xs.batch_stop_mask, xs.batch_append_mask]
             # `time` above is really max_time, whereas proper_time = sum(len(traj) for traj in xs))
             # which is what we need to give to GraphActionCategorical
         else:
@@ -119,18 +122,26 @@
             state_preds = out[:, 0:ns]
             stop_logits = out[:, ns : ns + 1]
             add_node_logits = out[:, ns + 1 :]
+            masks = [xs.tail_stop_mask, xs.tail_append_mask]
+
+        stop_logits = self._mask(stop_logits, masks[0])
+        add_node_logits = self._mask(add_node_logits, masks[1])

         return (
             GraphActionCategorical(
                 xs,
                 logits=[stop_logits, add_node_logits],
                 keys=[None, None],
                 types=self.ctx.action_type_order,
+                masks=masks,
                 slice_dict={},
             ),
             state_preds,
         )

+    def _mask(self, logits, mask):
+        return logits * mask + (1 - mask) * -1e6
+

 def generate_square_subsequent_mask(sz: int):
     """Generates an upper-triangular matrix of -inf, with zeros on diag."""
