@@ -117,6 +117,61 @@ def _encode_inputs(
        return (batch_size, questions, context, eos_user_tokens, bos_embeds,
                pad_embeds)

+    def _label_input_ids(self, i, label, eos_tokens):
+        label_input_ids = label.input_ids[i][:MAX_NEW_TOKENS]
+        label_input_ids += eos_tokens.input_ids  # Add EOS token.
+        return label_input_ids
+
+    def _input_ids(self, i, context, question, eos_user_tokens):
+        input_ids: List[int] = []
+        if context is not None:
+            input_ids += context.input_ids[i][:MAX_TXT_LEN]
+        input_ids += question.input_ids[i]
+        input_ids += eos_user_tokens.input_ids
+        return input_ids
+
+    def _inputs_embeds(self, i, input_ids, bos_embeds, embedding=None):
+        inputs_embeds = self.word_embedding(
+            torch.tensor(input_ids, device=self.llm_device))
+
+        to_cat = [bos_embeds]
+        if embedding is not None:
+            to_cat.append(embedding[i])
+        to_cat.append(inputs_embeds)
+        inputs_embeds = torch.cat([x.to(self.llm_device) for x in to_cat],
+                                  dim=0)
+        return inputs_embeds
+
+    def append_embeds(self, inputs_embeds, batch_inputs_embeds,
+                      batch_attention_mask, label_input_ids=None,
+                      batch_label_input_ids=None):
+        batch_inputs_embeds.append(inputs_embeds)
+        batch_attention_mask.append([1] * inputs_embeds.size(0))
+        if label_input_ids is not None:
+            # Mask the prompt positions so only the label tokens are scored:
+            label_input_ids = [IGNORE_INDEX] * (
+                inputs_embeds.size(0) - len(label_input_ids)) + label_input_ids
+            batch_label_input_ids.append(label_input_ids)
+        return batch_inputs_embeds, batch_attention_mask, batch_label_input_ids
+
+    def pad_embeds(self, pad_embeds, batch_inputs_embeds, batch_attention_mask,
+                   batch_label_input_ids=None):
+        # Left-pad every sequence in the batch to the maximum length:
+        max_length = max([x.size(0) for x in batch_inputs_embeds])
+        for i in range(len(batch_inputs_embeds)):
+            pad = max_length - batch_inputs_embeds[i].size(0)
+            batch_inputs_embeds[i] = torch.cat([
+                pad_embeds.repeat(pad, 1),
+                batch_inputs_embeds[i],
+            ])
+            batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
+            if batch_label_input_ids is not None:
+                batch_label_input_ids[i] = ([IGNORE_INDEX] * pad +
+                                            batch_label_input_ids[i])
+        inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
+        attention_mask = torch.tensor(batch_attention_mask,
+                                      device=self.llm_device)
+        if batch_label_input_ids is not None:
+            label_input_ids = torch.tensor(batch_label_input_ids,
+                                           device=self.llm_device)
+        else:
+            label_input_ids = None
+        return inputs_embeds, attention_mask, label_input_ids
+
    def forward(
        self,
        question: List[str],
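The append/pad helpers above implement left-padding: every sequence in the batch is prefixed with pad-token embeddings up to the length of the longest sequence, and the attention mask is zeroed over that prefix. Below is a minimal standalone sketch of that scheme; the function name `left_pad`, the toy shapes, and the all-zero pad embedding are illustrative assumptions, not part of the diff.

import torch

def left_pad(seqs, pad_embed):
    # seqs: list of [seq_len, hidden_dim] tensors; pad_embed: [1, hidden_dim].
    max_len = max(s.size(0) for s in seqs)
    padded, masks = [], []
    for s in seqs:
        pad = max_len - s.size(0)
        padded.append(torch.cat([pad_embed.repeat(pad, 1), s]))
        masks.append([0] * pad + [1] * s.size(0))
    return torch.stack(padded), torch.tensor(masks)

inputs_embeds, attention_mask = left_pad(
    [torch.randn(3, 4), torch.randn(5, 4)], torch.zeros(1, 4))
print(inputs_embeds.shape)  # torch.Size([2, 5, 4])
print(attention_mask[0])    # tensor([0, 0, 1, 1, 1])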
@@ -150,48 +205,15 @@ def forward(
        batch_attention_mask = []
        batch_label_input_ids = []
        for i in range(batch_size):
-            label_input_ids = label.input_ids[i][:MAX_NEW_TOKENS]
-            label_input_ids += eos_tokens.input_ids  # Add EOS token.
-
-            input_ids: List[int] = []
-            if context is not None:
-                input_ids += context.input_ids[i][:MAX_TXT_LEN]
-            input_ids += question.input_ids[i]
-            input_ids += eos_user_tokens.input_ids
+            label_input_ids = self._label_input_ids(i, label, eos_tokens)
+            input_ids = self._input_ids(i, context, question, eos_user_tokens)
            input_ids += label_input_ids

-            inputs_embeds = self.word_embedding(
-                torch.tensor(input_ids, device=self.llm_device))
-
-            to_cat = [bos_embeds]
-            if embedding is not None:
-                to_cat.append(embedding[i])
-            to_cat.append(inputs_embeds)
-            inputs_embeds = torch.cat(to_cat, dim=0)
-
-            batch_inputs_embeds.append(inputs_embeds)
-            batch_attention_mask.append([1] * inputs_embeds.size(0))
-            label_input_ids = [IGNORE_INDEX] * (
-                inputs_embeds.size(0) - len(label_input_ids)) + label_input_ids
-            batch_label_input_ids.append(label_input_ids)
+            inputs_embeds = self._inputs_embeds(i, input_ids, bos_embeds,
+                                                embedding)

-        # Pad input embeddings:
-        max_length = max([x.size(0) for x in batch_inputs_embeds])
-        for i in range(batch_size):
-            pad = max_length - batch_inputs_embeds[i].size(0)
-            batch_inputs_embeds[i] = torch.cat([
-                pad_embeds.repeat(pad, 1),
-                batch_inputs_embeds[i],
-            ])
-            batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
-            batch_label_input_ids[i] = ([IGNORE_INDEX] * pad +
-                                        batch_label_input_ids[i])
+            (batch_inputs_embeds, batch_attention_mask,
+             batch_label_input_ids) = self.append_embeds(
+                 inputs_embeds, batch_inputs_embeds, batch_attention_mask,
+                 label_input_ids, batch_label_input_ids)

-        inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
-        attention_mask = torch.tensor(batch_attention_mask,
-                                      device=self.llm_device)
-        label_input_ids = torch.tensor(batch_label_input_ids,
-                                       device=self.llm_device)
+        inputs_embeds, attention_mask, label_input_ids = self.pad_embeds(
+            pad_embeds, batch_inputs_embeds, batch_attention_mask,
+            batch_label_input_ids)

        with self.autocast_context:
            outputs = self.llm(
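In forward, the label tensor produced by append_embeds and pad_embeds is IGNORE_INDEX everywhere except over the answer tokens, so the language-modeling loss only covers the answer. A small sketch of that masking with a plain cross-entropy call follows; the logits, vocabulary size, and label values are made up, and IGNORE_INDEX is assumed to be -100, the value Hugging Face-style losses ignore.

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # Assumed loss-masking value.

logits = torch.randn(1, 6, 10)  # [batch, seq_len, vocab_size]
labels = torch.tensor([[IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 4, 7, 2]])

# Only the last three positions (the answer tokens) contribute to the loss:
loss = F.cross_entropy(logits.view(-1, 10), labels.view(-1),
                       ignore_index=IGNORE_INDEX)
print(loss)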
@@ -235,37 +257,14 @@ def inference(
        batch_attention_mask = []
        for i in range(batch_size):
            input_ids: List[int] = []
-            if context is not None:
-                input_ids = context.input_ids[i][:MAX_TXT_LEN]
-            input_ids += question.input_ids[i]
-            input_ids += eos_user_tokens.input_ids
-
-            inputs_embeds = self.word_embedding(
-                torch.tensor(input_ids, device=self.llm_device))
+            input_ids = self._input_ids(i, context, question, eos_user_tokens)

-            to_cat = [bos_embeds]
-            if embedding is not None:
-                to_cat.append(embedding[i])
-            to_cat.append(inputs_embeds)
-            inputs_embeds = torch.cat(to_cat, dim=0)
+            inputs_embeds = self._inputs_embeds(i, input_ids, bos_embeds,
+                                                embedding)

-            batch_inputs_embeds.append(inputs_embeds)
-            batch_attention_mask.append([1] * inputs_embeds.size(0))
+            batch_inputs_embeds, batch_attention_mask, _ = self.append_embeds(
+                inputs_embeds, batch_inputs_embeds, batch_attention_mask)

-        # Pad input embeddings:
-        max_length = max([x.size(0) for x in batch_inputs_embeds])
-        for i in range(batch_size):
-            pad = max_length - batch_inputs_embeds[i].size(0)
-            batch_inputs_embeds[i] = torch.cat([
-                pad_embeds.repeat(pad, 1),
-                batch_inputs_embeds[i],
-            ])
-            batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]
-
-        inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)
-        attention_mask = torch.tensor(batch_attention_mask,
-                                      device=self.llm_device)

+        inputs_embeds, attention_mask, _ = self.pad_embeds(
+            pad_embeds, batch_inputs_embeds, batch_attention_mask)
        bos_token = self.tokenizer(
            BOS,
            add_special_tokens=False,
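Both forward and inference now build the per-sample prompt through _input_ids: the optional context is truncated to MAX_TXT_LEN tokens, then the question tokens and the end-of-user-turn tokens are appended. A toy trace of that assembly, with made-up token ids and a made-up MAX_TXT_LEN value:

MAX_TXT_LEN = 4

context_ids = [11, 12, 13, 14, 15, 16]   # truncated to the first 4 ids
question_ids = [21, 22]
eos_user_ids = [99]

input_ids = context_ids[:MAX_TXT_LEN] + question_ids + eos_user_ids
print(input_ids)  # [11, 12, 13, 14, 21, 22, 99]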