@@ -16,8 +16,8 @@
 
 # For typing's sake, we'll pretend that a sequence is a graph.
 class Seq(Graph):
-    def __init__(self):
-        self.seq: list[Any] = []
+    def __init__(self, seq=None):
+        self.seq: list[Any] = [] if seq is None else seq
 
     def __repr__(self):
         return "".join(map(str, self.seq))
@@ -58,7 +58,8 @@ def reverse(self, g: Graph, ga: GraphAction):
 
 
 class SeqBatch:
-    def __init__(self, seqs: List[torch.Tensor], pad: int):
+    def __init__(self, seqs: List[torch.Tensor], pad: int, max_len: int = 2048):
+        self.max_len = max_len + 1  # +1 for BOS
         self.seqs = seqs
         self.x = pad_sequence(seqs, batch_first=False, padding_value=pad)
         self.mask = self.x.eq(pad).T
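
As a shape reminder, `pad_sequence` with `batch_first=False` stacks the batch time-major; a small self-contained sketch with hypothetical token ids:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    pad = 5  # hypothetical pad-token id
    seqs = [torch.tensor([4, 0, 1]), torch.tensor([4, 2])]  # BOS=4, then tokens
    x = pad_sequence(seqs, batch_first=False, padding_value=pad)
    # x.shape == (3, 2): (max_time, batch); the short sequence is padded with 5
    mask = x.eq(pad).T  # (batch, max_time) mask of pad positions, as in SeqBatch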
@@ -69,6 +70,14 @@ def __init__(self, seqs: List[torch.Tensor], pad: int):
         # Since we're feeding this batch object to graph-based algorithms, we have to use this naming, but this
         # is the total number of timesteps.
         self.num_graphs = self.lens.sum().item()
+        self.batch_stop_mask = torch.ones_like(self.logit_idx)[:, None]
+        self.batch_append_mask = (
+            torch.ones_like(self.logit_idx)
+            if self.lens.max() < self.max_len
+            else (self.logit_idx % self.max_len).lt(self.max_len - 1)
+        )[:, None].float()
+        self.tail_stop_mask = torch.ones((len(seqs), 1))
+        self.tail_append_mask = (self.lens[:, None] < self.max_len).float()
 
     def to(self, device):
         for name in dir(self):
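
A sketch of what the new tail masks evaluate to, assuming `lens` counts the BOS token (the lengths below are hypothetical, with `max_len=5` at construction, hence `self.max_len == 6`):

    import torch

    max_len = 5 + 1                   # mirrors self.max_len = max_len + 1 (BOS)
    lens = torch.tensor([4, 6, 7])    # hypothetical per-sequence lengths
    tail_stop_mask = torch.ones((3, 1))                   # Stop is always legal
    tail_append_mask = (lens[:, None] < max_len).float()
    # tensor([[1.], [0.], [0.]]): AddNode is masked out once a sequence reaches
    # max_len; batch_append_mask applies the same cutoff per flattened timestep.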
@@ -84,7 +93,7 @@ class AutoregressiveSeqBuildingContext(GraphBuildingEnvContext):
     This context gets an agent to generate sequences of tokens from left to right, i.e. in an autoregressive fashion.
     """
 
-    def __init__(self, alphabet: Sequence[str], num_cond_dim=0):
+    def __init__(self, alphabet: Sequence[str], num_cond_dim=0, max_len=None):
         self.alphabet = alphabet
         self.action_type_order = [GraphActionType.Stop, GraphActionType.AddNode]
 
@@ -93,6 +102,7 @@ def __init__(self, alphabet: Sequence[str], num_cond_dim=0):
         self.pad_token = len(alphabet) + 1
         self.num_actions = len(alphabet) + 1  # Alphabet + Stop
         self.num_cond_dim = num_cond_dim
+        self.max_len = max_len
 
     def aidx_to_GraphAction(self, g: Data, action_idx: Tuple[int, int, int], fwd: bool = True) -> GraphAction:
         # Since there's only one "object" per timestep to act upon (in graph parlance), the row is always == 0
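
A hypothetical construction showing how the new argument flows through (the alphabet is made up):

    ctx = AutoregressiveSeqBuildingContext(alphabet=["A", "C", "G", "T"], max_len=32)
    # ctx.max_len is forwarded to SeqBatch by collate() below, so batches built
    # from this context mask out AddNode for sequences that reach 32 tokens.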
@@ -120,7 +130,7 @@ def graph_to_Data(self, g: Graph):
         return torch.tensor([self.bos_token] + s.seq, dtype=torch.long)
 
     def collate(self, graphs: List[Data]):
-        return SeqBatch(graphs, pad=self.pad_token)
+        return SeqBatch(graphs, pad=self.pad_token, max_len=self.max_len)
 
     def is_sane(self, g: Graph) -> bool:
         return True
@@ -131,3 +141,6 @@ def graph_to_mol(self, g: Graph):
 
     def object_to_log_repr(self, g: Graph):
         return self.graph_to_mol(g)
+
+    def mol_to_graph(self, mol) -> Graph:
+        return mol
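
Because states in this context are `Seq` objects rather than RDKit molecules, the "mol" and "graph" representations coincide and `mol_to_graph` can be the identity; a quick hypothetical check, reusing the `ctx` from the sketch above:

    s = Seq(["A", "B"])
    assert ctx.mol_to_graph(s) is s  # sequences already are the "graphs" here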